I am working with FFTs with autograd turned on. As I understand it, the new torch.fft module does not yet support all autograd operations, so I am sticking to the old API. I have incorporated some of the suggestions I could find in other related discussion topics, but I might have missed something.
I am turning gradients on for an array in Fourier space and doing a bunch of operations, including forward and inverse FFTs. The actual code uses 3D FFTs and the update is more complex. Below is a 1D example that throws no errors, and a 2D example that throws an error similar to the one from my 3D code:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:
1D example (works)
import torch

IFFT_SIZE = (8,)
x_k = torch.rfft(torch.randn(IFFT_SIZE), 1)
# switch on grad and clone
x_k.requires_grad_(True)
x_k_clone = x_k.clone()
# inverse fft
x_r = torch.irfft(x_k_clone, 1, signal_sizes=IFFT_SIZE)
# define some update
dx = 2.0 * torch.rfft(x_r, 1)
# update
x_k_clone = x_k_clone + dx
# inverse fft
x_r = torch.irfft(x_k_clone, 1, signal_sizes=IFFT_SIZE)
(x_r * x_r).sum().backward()
2D example (does not work)
IFFT_SIZE = (8, 8)
x_r = torch.empty(2, *IFFT_SIZE)
x_k = torch.rfft(torch.randn(2, *IFFT_SIZE), 2)
dx = torch.zeros_like(x_k)
# switch on grad and clone
x_k.requires_grad_(True)
x_k_clone = x_k.clone()
# inverse fft
for i in range(2):
    x_r[i] = torch.irfft(x_k_clone[i], 2, signal_sizes=IFFT_SIZE)
# define some update
dx = 2.0 * torch.stack([torch.rfft(x_r[0], 2), torch.rfft(x_r[1], 2)], dim=0)
# update
x_k_clone = x_k_clone + dx
# inverse fft
for i in range(2):
    x_r[i] = torch.irfft(x_k_clone[i], 2, signal_sizes=IFFT_SIZE)
(x_r * x_r).sum().backward()
With torch.autograd.set_detect_anomaly(True), anomaly detection points to this line as the source of the error:

dx = 2.0 * torch.stack([torch.rfft(x_r[0], 2), torch.rfft(x_r[1], 2)], dim=0)
I don’t understand how this is an in-place operation. Any suggestions? Thanks!
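For what it's worth, my current guess is that the indexed assignments x_r[i] = ... are the in-place writes being flagged: the second loop overwrites x_r after its slices have been saved for the backward pass of the earlier torch.rfft calls. Below is a sketch of the 2D example that rebuilds x_r out-of-place with torch.stack instead of writing into a preallocated tensor. I believe this avoids the version-counter bump, but I haven't verified it against my 3D code, so take it as a guess rather than a confirmed fix:

import torch

IFFT_SIZE = (8, 8)
x_k = torch.rfft(torch.randn(2, *IFFT_SIZE), 2)

# switch on grad and clone
x_k.requires_grad_(True)
x_k_clone = x_k.clone()

# inverse fft, building x_r out-of-place instead of assigning into
# a preallocated tensor with x_r[i] = ...
x_r = torch.stack(
    [torch.irfft(x_k_clone[i], 2, signal_sizes=IFFT_SIZE) for i in range(2)],
    dim=0,
)

# define some update
dx = 2.0 * torch.stack([torch.rfft(x_r[0], 2), torch.rfft(x_r[1], 2)], dim=0)

# update
x_k_clone = x_k_clone + dx

# inverse fft, again rebinding x_r to a new tensor rather than mutating it
x_r = torch.stack(
    [torch.irfft(x_k_clone[i], 2, signal_sizes=IFFT_SIZE) for i in range(2)],
    dim=0,
)

(x_r * x_r).sum().backward()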