Autograd runtime error with operations involving FFTs

I am working with FFTs with autograd turned on. As I understand it, the new torch.fft module does not yet support all autograd operations, so I am sticking to the old API (torch.rfft / torch.irfft). I have incorporated some of the suggestions I could find in other related discussion topics, but I might have missed something.
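
For context, here is the old-API convention the examples below rely on (this little snippet is just my own illustration, not part of the problem code): torch.rfft returns a real tensor with a trailing dimension of size 2 holding the real and imaginary parts of the one-sided spectrum, and torch.irfft needs signal_sizes to recover the original signal length.

import torch

x = torch.randn(8)                        # real signal of length 8
X = torch.rfft(x, 1)                      # one-sided spectrum, shape (5, 2); last dim = (real, imag)
y = torch.irfft(X, 1, signal_sizes=(8,))  # back to a real signal of shape (8,)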

I am turning gradients on for an array in Fourier space and doing a number of operations, including forward and inverse FFTs. The actual code uses 3D FFTs and a more complex update. Below is a 1D example that runs without errors and a 2D example that throws an error similar to the one my 3D code throws:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:

1D example (works)

IFFT_SIZE = (8, )

x_k = torch.rfft(torch.randn(IFFT_SIZE), 1)

# switch on grad and clone
x_k.requires_grad_(True)
x_k_clone = x_k.clone()

# inverse fft
x_r = torch.irfft(x_k_clone, 1, signal_sizes=IFFT_SIZE)

# define some update
dx = 2.0 * torch.rfft(x_r, 1) 

# update
x_k_clone = x_k_clone + dx 

# inverse fft
x_r = torch.irfft(x_k_clone, 1, signal_sizes=IFFT_SIZE)

(x_r * x_r).sum().backward()

2D example (does not work)

IFFT_SIZE = (8, 8)

x_r = torch.empty(2, *IFFT_SIZE)
x_k = torch.rfft(torch.randn(2, *IFFT_SIZE), 2)
dx = torch.zeros_like(x_k)

# switch on grad and clone
x_k.requires_grad_(True)
x_k_clone = x_k.clone()

# inverse fft
for i in range(2):
    x_r[i] = torch.irfft(x_k_clone[i], 2, signal_sizes=IFFT_SIZE)

# define some update
dx = 2.0 * torch.stack([torch.rfft(x_r[0], 2), torch.rfft(x_r[1], 2)], dim=0)

# update
x_k_clone = x_k_clone + dx

# inverse fft
for i in range(2):
    x_r[i] = torch.irfft(x_k_clone[i], 2, signal_sizes=IFFT_SIZE)

(x_r * x_r).sum().backward()

With torch.autograd.set_detect_anomaly(True), the traceback points to this line as the source of the error:
dx = 2.0 * torch.stack([torch.rfft(x_r[0], 2), torch.rfft(x_r[1], 2)], dim=0)

I don’t understand how this is an in-place operation. Any suggestions? Thanks!

Fixed it. The in-place operation autograd was complaining about turned out to be the indexed assignment x_r[i] = torch.irfft(...) in the second loop: it overwrites x_r, whose slices had been used as inputs to (and saved by) the torch.rfft calls in the update. Key takeaways for fixing autograd problems like these:

  1. Use torch.autograd.set_detect_anomaly(True) to isolate the line with the problem (see the short sketch after this list).
  2. Call .clone() on the variables on the right-hand side of the update that are later modified in place, so that the tensors saved for backward are not the ones being overwritten.
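
A minimal sketch of where that switch goes (the call is the one used above; the placement is just my usual pattern):

import torch

# enable anomaly detection once, near the top of the script; backward() will
# then print a second traceback pointing at the forward-pass line that
# created the operation whose gradient computation failed
torch.autograd.set_detect_anomaly(True)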

I had to replace the line:

dx = 2.0 * torch.stack([torch.rfft(x_r[0], 2), torch.rfft(x_r[1], 2)], dim=0)

with:

dx = 2.0 * torch.stack([torch.rfft(x_r[0].clone(), 2), torch.rfft(x_r[1].clone(), 2)], dim=0)
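
Alternatively (not what I did above, just a sketch of the same idea), the in-place writes can be avoided altogether by building x_r with torch.stack instead of pre-allocating it with torch.empty and assigning into its slices. Then nothing saved for backward is ever overwritten, and no extra .clone() calls are needed:

import torch

IFFT_SIZE = (8, 8)

x_k = torch.rfft(torch.randn(2, *IFFT_SIZE), 2)

# switch on grad and clone
x_k.requires_grad_(True)
x_k_clone = x_k.clone()

# inverse fft, built out of place instead of writing into a pre-allocated x_r
x_r = torch.stack([torch.irfft(x_k_clone[i], 2, signal_sizes=IFFT_SIZE) for i in range(2)], dim=0)

# define some update
dx = 2.0 * torch.stack([torch.rfft(x_r[i], 2) for i in range(2)], dim=0)

# update
x_k_clone = x_k_clone + dx

# inverse fft, rebinding x_r rather than modifying it in place
x_r = torch.stack([torch.irfft(x_k_clone[i], 2, signal_sizes=IFFT_SIZE) for i in range(2)], dim=0)

(x_r * x_r).sum().backward()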

