I have a PyTorch model with a custom forward pass that applies torch.fft.irfft2 to the real component of a complex input tensor. I'm wondering whether this operation breaks gradient tracking through the network during training. Below is a simple example where, when I print output_image.grad, I consistently get a value of None.
import torch
import torch.nn as nn
import torch.optim as optim

# Model definition
class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_conv = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.encoder_relu = nn.ReLU()
        self.decoder_conv = nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1)
        self.decoder_relu = nn.ReLU()

    def forward(self, x):
        x = self.encoder_conv(x)
        x = self.encoder_relu(x)
        x = torch.fft.irfft2(x, s=(x.shape[-2], x.shape[-1]))
        x = self.decoder_conv(x)
        x = self.decoder_relu(x)
        return x

# Create an instance of the model
model = DummyModel()

# Set up loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Create a single input and target
input_image = torch.randn(1, 1, 64, 64)   # Real component of input image
target_image = torch.randn(1, 1, 64, 64)  # Real component of target image

# Training loop
for epoch in range(10):
    optimizer.zero_grad()
    output_image = model(input_image)
    loss = criterion(output_image, target_image)
    loss.backward()
    optimizer.step()

# Check the gradient of the output
print(output_image.grad)
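For context, this is the kind of check I was also planning to run after loss.backward(), to see whether gradients at least reach the model's leaf parameters (just a sketch reusing the model defined above):

# Check whether gradients reach the leaf parameters after loss.backward()
for name, param in model.named_parameters():
    grad_norm = None if param.grad is None else param.grad.norm().item()
    print(f"{name}: grad norm = {grad_norm}")

(If I understand correctly, output_image is a non-leaf tensor, so its .grad may only be populated if I call output_image.retain_grad() before backward(), but I'm not sure whether that's the whole story here.)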
I believe torch.fft is supposed to work with autograd and GPU acceleration, but I think I must be handling this in a really stupid way. Any help would be greatly appreciated!
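For what it's worth, a minimal standalone sanity check like the following (separate from the model above, with shapes chosen just for illustration) is what I'd use to see whether torch.fft.irfft2 on its own participates in autograd:

import torch

# Standalone sketch: does irfft2 propagate gradients back to a leaf input?
x = torch.randn(1, 1, 8, 8, requires_grad=True)  # real-valued leaf tensor
y = torch.fft.irfft2(x, s=(8, 8))                # same call pattern as in forward()
y.sum().backward()
print(x.grad is not None)  # I would expect True if irfft2 supports autograd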