TLDR…
Is it possible to have a network which performs complex convolutions on complex tensors, as well as real convolutions on non-complex tensors, in the same model?
Deeper dive…
Say we have the following model:
```python
import torch
import torch.nn as nn


class ComplexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # conv2 outputs 64 channels x 64 x 64 per sample (for a 1 x 64 x 64 input),
        # so fc1 needs 64 * 64 * 64 input features after flattening
        self.fc1 = nn.Linear(64 * 64 * 64, 4096)
        self.fc2 = nn.Linear(4096, 64 * 64)
        self.Hardtanh = nn.Hardtanh()

    def forward(self, x):
        # k-space -> image plane; shift only the spatial dims, not batch/channel
        x = torch.fft.ifft2(torch.fft.fftshift(x, dim=(-2, -1)))
        x_real = x.real
        x_imag = x.imag
        ### will this work? --> x = torch.sqrt(x_real**2 + x_imag**2)
        x = self.conv1(x)
        x = self.Hardtanh(x)
        x = self.conv2(x)
        x = self.Hardtanh(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.Hardtanh(x)
        x = self.fc2(x)
        x = self.Hardtanh(x)
        x = x.view(x.size(0), 1, 64, 64)
        # image plane -> k-space
        x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1)), dim=(-2, -1))
        return x
```
The input is a complex tensor, so to handle that we can explicitly cast the model to a complex dtype before training:

```python
model = model.type(torch.complex64)
```
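For reference, a quick check of what that cast does, using a small stand-in model (hypothetical, not `ComplexNet`): `Module.type` converts every parameter and buffer to the given dtype, so all weights and biases become `complex64`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model, only used to inspect the effect of the cast
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 10),
)

# Module.type() casts every parameter and buffer to the requested dtype
model = model.type(torch.complex64)

for name, p in model.named_parameters():
    print(name, p.dtype)  # each line reports torch.complex64
```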
However, the output of `x = torch.fft.ifft2(torch.fft.fftshift(x))` (i.e. moving from the complex Fourier plane to the image plane) is also complex, and because of the cast above, the subsequent convolutions are complex too.
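A small sketch of the dtype behaviour (the shapes are assumptions): `torch.fft.ifft2` returns a complex tensor for complex input, and `torch.fft.fft2` returns a complex tensor even for a real input. Note also that `torch.fft.fftshift` shifts every dimension when `dim` is not given, including the batch and channel axes, so `dim=(-2, -1)` is usually what is wanted for `(N, C, H, W)` tensors.

```python
import torch

# Hypothetical batch of 64 x 64 k-space samples, shape (N, C, H, W)
k = torch.randn(2, 1, 64, 64, dtype=torch.complex64)

# Shift only the spatial dims; with dim=None the batch/channel axes would be rolled too
img = torch.fft.ifft2(torch.fft.fftshift(k, dim=(-2, -1)))
print(img.dtype)  # torch.complex64

# Even a purely real image has a complex spectrum, so fft2 returns a complex tensor
real_img = torch.rand(2, 1, 64, 64)
spec = torch.fft.fft2(real_img)
print(spec.dtype)  # torch.complex64
```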
Is there a way to make the convolutions in the image plane real-valued only, while still letting gradients pass through the initial `x = torch.fft.ifft2(torch.fft.fftshift(x))` and final `x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1)))` operations, as they currently do with the complex-cast model?
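One possible approach, sketched under assumptions: skip the cast entirely. PyTorch's autograd differentiates through complex operations such as `torch.fft.ifft2`, `torch.abs` and `torch.fft.fft2` regardless of the module's dtype, so the parameters can stay `float32` while the network's input and output stay complex. `HybridNet`, the 16 x 16 size and the loss below are hypothetical choices for illustration, and the magnitude step discards phase, which may or may not be acceptable for a given problem.

```python
import torch
import torch.nn as nn


class HybridNet(nn.Module):
    """Hypothetical sketch: complex FFTs at both ends, real convolutions in between.

    The module is NOT cast to complex, so every weight and bias stays float32.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(8, 1, kernel_size=3, padding=1)
        self.act = nn.Hardtanh()

    def forward(self, k):
        # complex k-space -> complex image (shift spatial dims only)
        img = torch.fft.ifft2(torch.fft.fftshift(k, dim=(-2, -1)))
        # complex -> real magnitude; differentiable with respect to the complex input
        x = img.abs()
        # real convolutions with real (float32) weights
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        # real image -> complex k-space; fft2 accepts a real tensor and returns complex
        return torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))


model = HybridNet()
k = torch.randn(4, 1, 16, 16, dtype=torch.complex64, requires_grad=True)
target = torch.randn(4, 1, 16, 16, dtype=torch.complex64)

out = model(k)
# The loss must be a real scalar; mean squared magnitude of the complex error is one option
loss = (out - target).abs().pow(2).mean()
loss.backward()

print(out.dtype)                      # torch.complex64
print(model.conv1.weight.dtype)       # torch.float32
print(model.conv1.weight.grad.dtype)  # torch.float32
print(k.grad.dtype)                   # torch.complex64
```

Gradients reach both the real conv weights and the complex input, which is the behaviour the question asks for.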
One option I've considered is forcing the output of `x = torch.fft.ifft2(torch.fft.fftshift(x))` to be a real image with `x = torch.sqrt(x_real**2 + x_imag**2)`. But the convolution weights and biases would still be complex because of the model cast, so this would throw a dtype-mismatch error. Perhaps the same magnitude operation could also be applied to these parameters?
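A quick check of the magnitude idea and the anticipated error (tensor sizes are arbitrary): `torch.abs` on a complex tensor already computes `sqrt(real**2 + imag**2)` and returns a real tensor, and a convolution whose module was never cast accepts it directly. Feeding the complex tensor itself into that real convolution raises the dtype error, which suggests keeping the module uncast rather than taking the magnitude of the parameters.

```python
import torch
import torch.nn as nn

z = torch.randn(2, 1, 8, 8, dtype=torch.complex64)  # hypothetical complex image

# torch.abs on a complex tensor returns a real tensor equal to sqrt(real**2 + imag**2)
mag = torch.abs(z)
manual = torch.sqrt(z.real ** 2 + z.imag ** 2)
print(mag.dtype)                    # torch.float32
print(torch.allclose(mag, manual))  # True

conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)  # float32 weights, module never cast

# Passing the complex tensor straight into the real convolution fails on dtype
raised = False
try:
    conv(z)
except (RuntimeError, TypeError):
    raised = True
print(raised)  # True

# The real-valued magnitude works with the untouched real weights
out = conv(mag)
print(out.dtype)  # torch.float32
```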
Or is there a much better way to do this, and am I sort of butchering PyTorch with this example?
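A commonly used alternative, not from the original post: keep the phase by feeding the real and imaginary parts as two separate real channels, then recombine the two output channels with `torch.complex` before the final FFT. `TwoChannelNet` and its layer sizes are hypothetical; this is a sketch under those assumptions, not a recommendation for any particular task.

```python
import torch
import torch.nn as nn


class TwoChannelNet(nn.Module):
    """Hypothetical alternative: real and imaginary parts as two real channels."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 16, kernel_size=3, padding=1)  # in: (real, imag)
        self.conv2 = nn.Conv2d(16, 2, kernel_size=3, padding=1)  # out: (real, imag)
        self.act = nn.Hardtanh()

    def forward(self, k):
        # complex k-space (N, 1, H, W) -> complex image (N, 1, H, W)
        img = torch.fft.ifft2(torch.fft.fftshift(k, dim=(-2, -1)))
        # split into a real (N, 2, H, W) tensor; unlike the magnitude, phase is kept
        x = torch.cat([img.real, img.imag], dim=1)
        x = self.act(self.conv1(x))
        x = self.conv2(x)
        # reassemble a complex (N, 1, H, W) image from the two output channels
        img_out = torch.complex(x[:, 0:1], x[:, 1:2])
        return torch.fft.fftshift(torch.fft.fft2(img_out), dim=(-2, -1))


model = TwoChannelNet()
k = torch.randn(4, 1, 16, 16, dtype=torch.complex64, requires_grad=True)
out = model(k)
loss = out.abs().pow(2).mean()  # any real-valued scalar loss works here
loss.backward()
print(out.shape, out.dtype)
```

The parameters stay `float32` throughout, and gradients flow back through `torch.complex`, the convolutions and both FFTs to the complex input.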
Thanks for any advice!