One of the data processing steps in my model applies an FFT and/or IFFT to an arbitrary tensor.
Things work nicely as long as I keep the tensor dimensions small.
But once the tensor reaches a certain size, the FFT and IFFT run on the GPU no longer produce values similar to the CPU.
import torch

# Load a complex64 tensor on the CPU and make a GPU copy of it.
a = torch.load('H_fft_2000.pt')
b = a.clone().cuda()
print(f'a.shape : {a.shape}')
print(f'b.shape : {b.shape}')
print(f'a.device : {a.device}')
print(f'a.real Max(): {a.real.max()}, Min(): {a.real.min()}')
print(f'a.imag Max(): {a.imag.max()}, Min(): {a.imag.min()}')
print(f'a.abs() Max(): {a.abs().max()}, Min(): {a.abs().min()}')
print(f'b.device : {b.device}')
print(f'b.real Max(): {b.real.max()}, Min(): {b.real.min()}')
print(f'b.imag Max(): {b.imag.max()}, Min(): {b.imag.min()}')
print(f'b.abs() Max(): {b.abs().max()}, Min(): {b.abs().min()}')
# 2D FFT over the last two dimensions, on CPU and GPU respectively
a_fft = torch.fft.fftn(a, dim=[-2,-1], norm="ortho")
b_fft = torch.fft.fftn(b, dim=[-2,-1], norm="ortho")
print(f'a_fft.device : {a_fft.device}')
print(f'a_fft.real Max(): {a_fft.real.max()}, Min(): {a_fft.real.min()}')
print(f'a_fft.imag Max(): {a_fft.imag.max()}, Min(): {a_fft.imag.min()}')
print(f'a_fft.abs() Max(): {a_fft.abs().max()}, Min(): {a_fft.abs().min()}')
print(f'b_fft.device : {b_fft.device}')
print(f'b_fft.real Max(): {b_fft.real.max()}, Min(): {b_fft.real.min()}')
print(f'b_fft.imag Max(): {b_fft.imag.max()}, Min(): {b_fft.imag.min()}')
print(f'b_fft.abs() Max(): {b_fft.abs().max()}, Min(): {b_fft.abs().min()}')
# Inverse 2D FFT; with matching norm this should reconstruct a and b
a_ifft = torch.fft.ifftn(a_fft, dim=[-2,-1], norm="ortho")
b_ifft = torch.fft.ifftn(b_fft, dim=[-2,-1], norm="ortho")
print(f'a_ifft.device : {a_ifft.device}')
print(f'a_ifft.real Max(): {a_ifft.real.max()}, Min(): {a_ifft.real.min()}')
print(f'a_ifft.imag Max(): {a_ifft.imag.max()}, Min(): {a_ifft.imag.min()}')
print(f'a_ifft.abs() Max(): {a_ifft.abs().max()}, Min(): {a_ifft.abs().min()}')
print(f'b_ifft.device : {b_ifft.device}')
print(f'b_ifft.real Max(): {b_ifft.real.max()}, Min(): {b_ifft.real.min()}')
print(f'b_ifft.imag Max(): {b_ifft.imag.max()}, Min(): {b_ifft.imag.min()}')
print(f'b_ifft.abs() Max(): {b_ifft.abs().max()}, Min(): {b_ifft.abs().min()}')
Running the code above with an arbitrary tensor (H_fft_2000.pt) gives me OK results regardless of the normalization mode of torch.fft.fftn()/ifftn():
a.shape : torch.Size([3999, 3999])
b.shape : torch.Size([3999, 3999])
a.device : cpu
a.real Max(): 1227899666432.0, Min(): -1286383861760.0
a.imag Max(): 1345065058304.0, Min(): -1391184314368.0
a.abs() Max(): 1392092512256.0, Min(): 3149313.0
b.device : cuda:0
b.real Max(): 1227899666432.0, Min(): -1286383861760.0
b.imag Max(): 1345065058304.0, Min(): -1391184314368.0
b.abs() Max(): 1392092512256.0, Min(): 3149313.0
a_fft.device : cpu
a_fft.real Max(): 63194193920.0, Min(): -63188475904.0
a_fft.imag Max(): 63191957504.0, Min(): -63188336640.0
a_fft.abs() Max(): 63195369472.0, Min(): 63144849408.0
b_fft.device : cuda:0
b_fft.real Max(): 63194198016.0, Min(): -63188475904.0
b_fft.imag Max(): 63191961600.0, Min(): -63188340736.0
b_fft.abs() Max(): 63195381760.0, Min(): 63144853504.0
a_ifft.device : cpu
a_ifft.real Max(): 1227899928576.0, Min(): -1286384123904.0
a_ifft.imag Max(): 1345065189376.0, Min(): -1391184576512.0
a_ifft.abs() Max(): 1392092774400.0, Min(): 3148201.0
b_ifft.device : cuda:0
b_ifft.real Max(): 1227899928576.0, Min(): -1286384123904.0
b_ifft.imag Max(): 1345065189376.0, Min(): -1391184445440.0
b_ifft.abs() Max(): 1392092774400.0, Min(): 3139220.5
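As an aside, comparing the max/min printouts by eye is error-prone, so here is a small helper I use to quantify the CPU/GPU disagreement directly (`rel_err` is my own name, not a torch API; the comparison is done in complex128 so it adds no rounding noise of its own):

```python
import torch

def rel_err(cpu_ref, gpu_out):
    # Maximum relative error between a CPU reference and a GPU result,
    # accumulated in double precision so the comparison itself is exact.
    ref = cpu_ref.to(torch.complex128)
    got = gpu_out.cpu().to(torch.complex128)
    return ((ref - got).abs() / ref.abs().clamp_min(1e-30)).max().item()

# e.g. rel_err(a_fft, b_fft) should be on the order of 1e-6 or smaller
# for float32/complex64 inputs if both backends are computing correctly.
```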
However, once the input tensor gets larger (H_fft_2000.pt → H_fft_3000.pt)…
a = torch.load('H_fft_3000.pt')
b = a.clone().cuda()
...
a_fft = torch.fft.fftn(a, dim=[-2,-1], norm="forward")
b_fft = torch.fft.fftn(b, dim=[-2,-1], norm="forward")
...
a_ifft = torch.fft.ifftn(a_fft, dim=[-2,-1], norm="forward")
b_ifft = torch.fft.ifftn(b_fft, dim=[-2,-1], norm="forward")
...
Running the code above with this larger tensor (H_fft_3000.pt) and “forward” normalization gives wrong results on the GPU for both torch.fft.fftn() and torch.fft.ifftn():
a.shape : torch.Size([5999, 5999])
b.shape : torch.Size([5999, 5999])
a.device : cpu
a.real Max(): 1317886492672.0, Min(): -1305796542464.0
a.imag Max(): 1357465518080.0, Min(): -1370446233600.0
a.abs() Max(): 1387183996928.0, Min(): 554983.6875
b.device : cuda:0
b.real Max(): 1317886492672.0, Min(): -1305796542464.0
b.imag Max(): 1357465518080.0, Min(): -1370446233600.0
b.abs() Max(): 1387183996928.0, Min(): 554983.6875
a_fft.device : cpu
a_fft.real Max(): 15802503.0, Min(): -15801069.0
a_fft.imag Max(): 15801941.0, Min(): -15801037.0
a_fft.abs() Max(): 15802796.0, Min(): 15774392.0
b_fft.device : cuda:0
b_fft.real Max(): 38545.73828125, Min(): 0.015421354211866856
b_fft.imag Max(): 0.0, Min(): 0.0
b_fft.abs() Max(): 38545.73828125, Min(): 0.015421354211866856
a_ifft.device : cpu
a_ifft.real Max(): 1317886230528.0, Min(): -1305796411392.0
a_ifft.imag Max(): 1357465124864.0, Min(): -1370445971456.0
a_ifft.abs() Max(): 1387183603712.0, Min(): 543693.625
b_ifft.device : cuda:0
b_ifft.real Max(): 38545.73828125, Min(): 0.015421354211866856
b_ifft.imag Max(): 0.0, Min(): 0.0
b_ifft.abs() Max(): 38545.73828125, Min(): 0.015421354211866856
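Since I can't attach H_fft_3000.pt, here is a scaled-down, self-contained version of the same comparison with random data of the same dtype and magnitude standing in for it (512×512 instead of 5999×5999; at this smaller size I'd expect the CPU and GPU results to agree):

```python
import torch

# Random complex64 tensor with roughly the same magnitude (~1e12)
# as H_fft_3000.pt, as a stand-in for the real data.
torch.manual_seed(0)
x = (torch.randn(512, 512) + 1j * torch.randn(512, 512)).to(torch.complex64) * 1e12
x_fft = torch.fft.fftn(x, dim=[-2, -1], norm="forward")
print(f'cpu  abs max: {x_fft.abs().max()}')
if torch.cuda.is_available():
    y_fft = torch.fft.fftn(x.cuda(), dim=[-2, -1], norm="forward")
    print(f'cuda abs max: {y_fft.abs().max()}')
```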
If I change the normalization to “ortho”, the error occurs only in ifftn():
a = torch.load('H_fft_3000.pt')
b = a.clone().cuda()
...
a_fft = torch.fft.fftn(a, dim=[-2,-1], norm="ortho")
b_fft = torch.fft.fftn(b, dim=[-2,-1], norm="ortho")
...
a_ifft = torch.fft.ifftn(a_fft, dim=[-2,-1], norm="ortho")
b_ifft = torch.fft.ifftn(b_fft, dim=[-2,-1], norm="ortho")
...
The result of the code above is…
a.shape : torch.Size([5999, 5999])
b.shape : torch.Size([5999, 5999])
a.device : cpu
a.real Max(): 1317886492672.0, Min(): -1305796542464.0
a.imag Max(): 1357465518080.0, Min(): -1370446233600.0
a.abs() Max(): 1387183996928.0, Min(): 554983.6875
b.device : cuda:0
b.real Max(): 1317886492672.0, Min(): -1305796542464.0
b.imag Max(): 1357465518080.0, Min(): -1370446233600.0
b.abs() Max(): 1387183996928.0, Min(): 554983.6875
a_fft.device : cpu
a_fft.real Max(): 94799216640.0, Min(): -94790606848.0
a_fft.imag Max(): 94795833344.0, Min(): -94790393856.0
a_fft.abs() Max(): 94800961536.0, Min(): 94630592512.0
b_fft.device : cuda:0
b_fft.real Max(): 94799273984.0, Min(): -94790639616.0
b_fft.imag Max(): 94795882496.0, Min(): -94790459392.0
b_fft.abs() Max(): 94801018880.0, Min(): 94630559744.0
a_ifft.device : cpu
a_ifft.real Max(): 1317886099456.0, Min(): -1305796280320.0
a_ifft.imag Max(): 1357464993792.0, Min(): -1370445840384.0
a_ifft.abs() Max(): 1387183341568.0, Min(): 552186.8125
b_ifft.device : cuda:0
b_ifft.real Max(): 15802803.0, Min(): 15774389.0
b_ifft.imag Max(): 0.0, Min(): 0.0
b_ifft.abs() Max(): 15802803.0, Min(): 15774389.0
What could be the cause of this problem?
I suspect some sort of overflow/underflow during the FFT/IFFT, but I have my doubts, since the precision is the same on GPU and CPU (torch.complex64 for complex tensors, torch.float32 for real ones) and no warning or exception is raised during the FFT/IFFT.
Any advice?
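One experiment I'm considering, to test the overflow hypothesis (a sketch only, scaled down to a size that runs anywhere): do the same fftn/ifftn round trip in complex128, where overflow at these magnitudes should be far less likely, and compare the reconstruction errors. On the GPU the same code can be repeated with `.cuda()` tensors:

```python
import torch

def roundtrip(x, norm="ortho"):
    # fftn followed by ifftn with the same norm should reproduce x
    # up to the rounding error of the dtype.
    y = torch.fft.fftn(x, dim=[-2, -1], norm=norm)
    return torch.fft.ifftn(y, dim=[-2, -1], norm=norm)

torch.manual_seed(0)
x32 = (torch.randn(64, 64) + 1j * torch.randn(64, 64)).to(torch.complex64) * 1e12
x64 = x32.to(torch.complex128)

# Relative reconstruction error for single vs. double precision.
err32 = ((roundtrip(x32) - x32).abs().max() / x32.abs().max()).item()
err64 = ((roundtrip(x64) - x64).abs().max() / x64.abs().max()).item()
```

If the GPU failure really is a precision issue, the complex128 round trip should stay accurate at the full 5999×5999 size while the complex64 one breaks down.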