Does torch.fft.irfft2 break gradient tracking in PyTorch when wrapped in an nn.Module?

I have a PyTorch model with a custom forward pass that involves applying torch.fft.irfft2 to the real component of a complex input tensor.

I’m wondering whether this operation breaks gradient tracking through the network during training. Below is a simple example in which printing output_image.grad consistently gives None.

import torch
import torch.nn as nn
import torch.optim as optim

# Model definition
class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_conv = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.encoder_relu = nn.ReLU()
        self.decoder_conv = nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1)
        self.decoder_relu = nn.ReLU()

    def forward(self, x):
        x = self.encoder_conv(x)
        x = self.encoder_relu(x)
        x = torch.fft.irfft2(x, s=(x.shape[-2], x.shape[-1]))  # inverse real FFT over the last two spatial dims
        x = self.decoder_conv(x)
        x = self.decoder_relu(x)
        return x

# Create an instance of the model
model = DummyModel()

# Set up loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Create a single input and target
input_image = torch.randn(1, 1, 64, 64)  # Real component of input image
target_image = torch.randn(1, 1, 64, 64)  # Real component of target image

# Training loop
for epoch in range(10):
    optimizer.zero_grad()
    output_image = model(input_image)
    loss = criterion(output_image, target_image)
    loss.backward()
    optimizer.step()

    # Check the gradient of the output
    print(output_image.grad)

I believe torch.fft is supposed to work with autograd and GPU acceleration, so I think I must be handling this in a really stupid way. Any help would be greatly appreciated!

You are receiving a warning, which you shouldn’t ignore:

UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.

Calling:

output_image = model(input_image)
output_image.retain_grad()

properly shows the gradient in this non-leaf tensor.
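
E.g. a minimal sketch of your training loop with the retain_grad() call added (using the variable names from your example):

for epoch in range(10):
    optimizer.zero_grad()
    output_image = model(input_image)
    output_image.retain_grad()  # keep .grad for this non-leaf tensor
    loss = criterion(output_image, target_image)
    loss.backward()
    optimizer.step()

    print(output_image.grad)  # now a tensor instead of None

Also note that the parameter gradients (e.g. model.encoder_conv.weight.grad) are populated in any case, since the parameters are leaf tensors.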


Thanks @ptrblck as always!

Just to tie up my understanding: is the root cause that torch.fft.irfft2 is a plain function rather than an nn.something module, which makes PyTorch treat everything beyond torch.fft.irfft2 as no longer a leaf Tensor?

Or have I got that completely wrong? :')

No, non-leaf tensors are tensors that were created by an operation, as opposed to tensors that were created explicitly by the user without an Autograd history or e.g. as parameters.
Any operation will raise the same warning:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, requires_grad=True)
print(x.is_leaf)
# True

out = F.relu(x)
print(out.is_leaf)
# False

loss = out.mean()
loss.backward()

print(out.grad)
# UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.

You don’t need to use nn.Modules and can use any differentiable operation, since modules use the same functional API internally.
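
E.g. a small sketch showing that the gradient flows through torch.fft.irfft2 back to the leaf parameters of a plain nn.Conv2d (shapes chosen arbitrarily for the example):

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 16, kernel_size=3, padding=1)
x = torch.randn(1, 1, 8, 8)

out = conv(x)
out = torch.fft.irfft2(out, s=(out.shape[-2], out.shape[-1]))  # plain functional op
out.mean().backward()

print(conv.weight.grad is not None)
# True: the gradient was propagated through irfft2 to the leaf parameter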

Ah, I understand! Thanks @ptrblck!