Different torch.fft.fftn/ifftn() results on CPU and GPU?

One of the data processing steps in my model applies an FFT and/or IFFT to an arbitrary tensor.
Things work nicely as long as I keep the tensor's dimensions small.
But once it reaches a certain size, the FFT and IFFT run on the GPU no longer produce values similar to the CPU results.

a = torch.load('H_fft_2000.pt')
b = a.clone().cuda()
print(f'a.shape : {a.shape}')
print(f'b.shape : {b.shape}')

print(f'a.device : {a.device}')
print(f'a.real  Max(): {a.real.max()}, Min(): {a.real.min()}')
print(f'a.imag  Max(): {a.imag.max()}, Min(): {a.imag.min()}')
print(f'a.abs() Max(): {a.abs().max()}, Min(): {a.abs().min()}')
print(f'b.device : {b.device}')
print(f'b.real  Max(): {b.real.max()}, Min(): {b.real.min()}')
print(f'b.imag  Max(): {b.imag.max()}, Min(): {b.imag.min()}')
print(f'b.abs() Max(): {b.abs().max()}, Min(): {b.abs().min()}')

a_fft = torch.fft.fftn(a, dim=[-2,-1], norm="ortho")
b_fft = torch.fft.fftn(b, dim=[-2,-1], norm="ortho")

print(f'a_fft.device : {a_fft.device}')
print(f'a_fft.real  Max(): {a_fft.real.max()}, Min(): {a_fft.real.min()}')
print(f'a_fft.imag  Max(): {a_fft.imag.max()}, Min(): {a_fft.imag.min()}')
print(f'a_fft.abs() Max(): {a_fft.abs().max()}, Min(): {a_fft.abs().min()}')
print(f'b_fft.device : {b_fft.device}')
print(f'b_fft.real  Max(): {b_fft.real.max()}, Min(): {b_fft.real.min()}')
print(f'b_fft.imag  Max(): {b_fft.imag.max()}, Min(): {b_fft.imag.min()}')
print(f'b_fft.abs() Max(): {b_fft.abs().max()}, Min(): {b_fft.abs().min()}')

a_ifft = torch.fft.ifftn(a_fft, dim=[-2,-1], norm="ortho")
b_ifft = torch.fft.ifftn(b_fft, dim=[-2,-1], norm="ortho")

print(f'a_ifft.device : {a_ifft.device}')
print(f'a_ifft.real  Max(): {a_ifft.real.max()}, Min(): {a_ifft.real.min()}')
print(f'a_ifft.imag  Max(): {a_ifft.imag.max()}, Min(): {a_ifft.imag.min()}')
print(f'a_ifft.abs() Max(): {a_ifft.abs().max()}, Min(): {a_ifft.abs().min()}')
print(f'b_ifft.device : {b_ifft.device}')
print(f'b_ifft.real  Max(): {b_ifft.real.max()}, Min(): {b_ifft.real.min()}')
print(f'b_ifft.imag  Max(): {b_ifft.imag.max()}, Min(): {b_ifft.imag.min()}')
print(f'b_ifft.abs() Max(): {b_ifft.abs().max()}, Min(): {b_ifft.abs().min()}')

Running the code above with an arbitrary tensor (H_fft_2000.pt) gives me OK results regardless of the normalization mode of torch.fft.fftn/ifftn().
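For a stricter check than eyeballing Max()/Min(), the CPU and GPU transforms can be compared element-wise. A sketch using a small random stand-in tensor (since H_fft_2000.pt is not available here; size and tolerances are assumptions):

```python
import torch

torch.manual_seed(0)
# Random complex64 stand-in for H_fft_2000.pt
a = torch.randn(256, 256) + 1j * torch.randn(256, 256)

a_fft = torch.fft.fftn(a, dim=[-2, -1], norm="ortho")

if torch.cuda.is_available():
    b_fft = torch.fft.fftn(a.cuda(), dim=[-2, -1], norm="ortho")
    # Element-wise comparison is stricter than comparing Max()/Min() alone
    print(torch.allclose(a_fft, b_fft.cpu(), rtol=1e-4, atol=1e-6))
```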

a.shape : torch.Size([3999, 3999])
b.shape : torch.Size([3999, 3999])
a.device : cpu
a.real  Max(): 1227899666432.0, Min(): -1286383861760.0
a.imag  Max(): 1345065058304.0, Min(): -1391184314368.0
a.abs() Max(): 1392092512256.0, Min(): 3149313.0
b.device : cuda:0
b.real  Max(): 1227899666432.0, Min(): -1286383861760.0
b.imag  Max(): 1345065058304.0, Min(): -1391184314368.0
b.abs() Max(): 1392092512256.0, Min(): 3149313.0
a_fft.device : cpu
a_fft.real  Max(): 63194193920.0, Min(): -63188475904.0
a_fft.imag  Max(): 63191957504.0, Min(): -63188336640.0
a_fft.abs() Max(): 63195369472.0, Min(): 63144849408.0
b_fft.device : cuda:0
b_fft.real  Max(): 63194198016.0, Min(): -63188475904.0
b_fft.imag  Max(): 63191961600.0, Min(): -63188340736.0
b_fft.abs() Max(): 63195381760.0, Min(): 63144853504.0
a_ifft.device : cpu
a_ifft.real  Max(): 1227899928576.0, Min(): -1286384123904.0
a_ifft.imag  Max(): 1345065189376.0, Min(): -1391184576512.0
a_ifft.abs() Max(): 1392092774400.0, Min(): 3148201.0
b_ifft.device : cuda:0
b_ifft.real  Max(): 1227899928576.0, Min(): -1286384123904.0
b_ifft.imag  Max(): 1345065189376.0, Min(): -1391184445440.0
b_ifft.abs() Max(): 1392092774400.0, Min(): 3139220.5

However, once the input tensor gets larger… (H_fft_2000.pt → H_fft_3000.pt)

a = torch.load('H_fft_3000.pt')
b = a.clone().cuda()

...

a_fft = torch.fft.fftn(a, dim=[-2,-1], norm="forward")
b_fft = torch.fft.fftn(b, dim=[-2,-1], norm="forward")

...

a_ifft = torch.fft.ifftn(a_fft, dim=[-2,-1], norm="forward")
b_ifft = torch.fft.ifftn(b_fft, dim=[-2,-1], norm="forward")

...

Running the code above with an arbitrary tensor (H_fft_3000.pt) and "forward" normalization does not give OK results for either torch.fft.fftn() or ifftn().

a.shape : torch.Size([5999, 5999])
b.shape : torch.Size([5999, 5999])
a.device : cpu
a.real  Max(): 1317886492672.0, Min(): -1305796542464.0
a.imag  Max(): 1357465518080.0, Min(): -1370446233600.0
a.abs() Max(): 1387183996928.0, Min(): 554983.6875
b.device : cuda:0
b.real  Max(): 1317886492672.0, Min(): -1305796542464.0
b.imag  Max(): 1357465518080.0, Min(): -1370446233600.0
b.abs() Max(): 1387183996928.0, Min(): 554983.6875
a_fft.device : cpu
a_fft.real  Max(): 15802503.0, Min(): -15801069.0
a_fft.imag  Max(): 15801941.0, Min(): -15801037.0
a_fft.abs() Max(): 15802796.0, Min(): 15774392.0
b_fft.device : cuda:0
b_fft.real  Max(): 38545.73828125, Min(): 0.015421354211866856
b_fft.imag  Max(): 0.0, Min(): 0.0
b_fft.abs() Max(): 38545.73828125, Min(): 0.015421354211866856
a_ifft.device : cpu
a_ifft.real  Max(): 1317886230528.0, Min(): -1305796411392.0
a_ifft.imag  Max(): 1357465124864.0, Min(): -1370445971456.0
a_ifft.abs() Max(): 1387183603712.0, Min(): 543693.625
b_ifft.device : cuda:0
b_ifft.real  Max(): 38545.73828125, Min(): 0.015421354211866856
b_ifft.imag  Max(): 0.0, Min(): 0.0
b_ifft.abs() Max(): 38545.73828125, Min(): 0.015421354211866856
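If the failure turns out to be specific to the complex64 path on the GPU, one hypothetical workaround is to run the transform in double precision and cast back (at the cost of double the memory and slower execution). A sketch with a random stand-in tensor:

```python
import torch

def fftn_f64(x, dim=(-2, -1), norm="ortho"):
    # Hypothetical workaround: transform in complex128, then cast back.
    # Only useful if the failure is specific to the complex64 path.
    return torch.fft.fftn(x.to(torch.complex128), dim=dim, norm=norm).to(x.dtype)

torch.manual_seed(0)
x = torch.randn(64, 64) + 1j * torch.randn(64, 64)  # stand-in tensor
y = fftn_f64(x)
print(y.dtype)  # torch.complex64
```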

If I change the normalization to “ortho”, then the error only occurs for ifftn().

a = torch.load('H_fft_3000.pt')
b = a.clone().cuda()

...

a_fft = torch.fft.fftn(a, dim=[-2,-1], norm="ortho")
b_fft = torch.fft.fftn(b, dim=[-2,-1], norm="ortho")

...

a_ifft = torch.fft.ifftn(a_fft, dim=[-2,-1], norm="ortho")
b_ifft = torch.fft.ifftn(b_fft, dim=[-2,-1], norm="ortho")

...

The result of the code above is…

a.shape : torch.Size([5999, 5999])
b.shape : torch.Size([5999, 5999])
a.device : cpu
a.real  Max(): 1317886492672.0, Min(): -1305796542464.0
a.imag  Max(): 1357465518080.0, Min(): -1370446233600.0
a.abs() Max(): 1387183996928.0, Min(): 554983.6875
b.device : cuda:0
b.real  Max(): 1317886492672.0, Min(): -1305796542464.0
b.imag  Max(): 1357465518080.0, Min(): -1370446233600.0
b.abs() Max(): 1387183996928.0, Min(): 554983.6875
a_fft.device : cpu
a_fft.real  Max(): 94799216640.0, Min(): -94790606848.0
a_fft.imag  Max(): 94795833344.0, Min(): -94790393856.0
a_fft.abs() Max(): 94800961536.0, Min(): 94630592512.0
b_fft.device : cuda:0
b_fft.real  Max(): 94799273984.0, Min(): -94790639616.0
b_fft.imag  Max(): 94795882496.0, Min(): -94790459392.0
b_fft.abs() Max(): 94801018880.0, Min(): 94630559744.0
a_ifft.device : cpu
a_ifft.real  Max(): 1317886099456.0, Min(): -1305796280320.0
a_ifft.imag  Max(): 1357464993792.0, Min(): -1370445840384.0
a_ifft.abs() Max(): 1387183341568.0, Min(): 552186.8125
b_ifft.device : cuda:0
b_ifft.real  Max(): 15802803.0, Min(): 15774389.0
b_ifft.imag  Max(): 0.0, Min(): 0.0
b_ifft.abs() Max(): 15802803.0, Min(): 15774389.0
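A round-trip check (ifftn(fftn(x)) ≈ x) makes this kind of failure easy to catch automatically. A sketch on the CPU with a small random stand-in tensor; running the same check on the GPU with the problematic sizes should expose the bug:

```python
import torch

torch.manual_seed(0)
x = torch.randn(128, 128) + 1j * torch.randn(128, 128)  # stand-in tensor

# ifftn(fftn(x)) should recover x up to float32 rounding noise
y = torch.fft.ifftn(torch.fft.fftn(x, dim=[-2, -1], norm="ortho"),
                    dim=[-2, -1], norm="ortho")
print(torch.allclose(x, y, rtol=1e-4, atol=1e-5))
```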

What could be the cause of this problem?
I suspect some sort of over/underflow during the FFT/IFFT, but I have doubts, since the precision is the same on the GPU and the CPU (torch.complex64 for complex numbers, torch.float32 for floats), and no warning or exception is raised during the FFT/IFFT.
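One way to test the over/underflow suspicion without a GPU is to compare the float32 transform against a float64 reference on the CPU. A sketch with a random stand-in tensor (size and magnitudes are assumptions, chosen to roughly match the reported ~1e12 values):

```python
import torch

torch.manual_seed(0)
# Stand-in with magnitudes comparable to the reported data
x = (torch.randn(512, 512) + 1j * torch.randn(512, 512)) * 1e12

ref = torch.fft.fftn(x.to(torch.complex128), dim=[-2, -1], norm="ortho")
out = torch.fft.fftn(x, dim=[-2, -1], norm="ortho")

# If float32 were overflowing, this relative error would blow up
rel_err = ((out.to(torch.complex128) - ref).abs().max() / ref.abs().max()).item()
print(f'max relative error vs float64 reference: {rel_err:.3e}')
```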

Any advice?


python: 3.6.9
torch: 1.8.1
cuda: 11.1 (with an RTX 2060 and a GTX 1060, i.e. no AMP and no TF32)

Hi,

Thanks for the report. Could you open an issue on GitHub with your scripts to reproduce this, please? This might be an issue on our end or in a third-party lib we're using.

Just opened an issue (as a bug report) on GitHub.
Please have a look and let me know if I made any mistakes using the library.