Hi, I’m trying to implement a Fourier unit as part of an object detection model. The paper author's original code targets an old version of PyTorch and uses torch.rfft() and torch.irfft(), which have been replaced in newer versions by the functions in the torch.fft module (torch.fft.<transform>).
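For reference, here is a minimal sketch of how the removed calls map onto torch.fft. It assumes the author's code calls them as torch.rfft(x, signal_ndim=2, normalized=True) and torch.irfft(..., signal_ndim=2, signal_sizes=..., normalized=True); those exact arguments are an assumption and should be checked against the original repository. In the old API, normalized=True corresponds to norm='ortho', and the default onesided=True corresponds to rfft2 (a one-sided spectrum of width w//2+1) rather than fft2. The input sizes below are arbitrary.

```python
import torch

# Arbitrary example input: batch=2, channels=3, h=8, w=10
x = torch.randn(2, 3, 8, 10)

# Old (removed): torch.rfft(x, signal_ndim=2, normalized=True)
# returned a real tensor of shape (2, 3, 8, 10 // 2 + 1, 2).
# New equivalent: one-sided 2-D FFT over the last two dims, then expose
# the real/imaginary parts as a trailing dimension of size 2.
spec = torch.view_as_real(torch.fft.rfft2(x, norm='ortho'))
print(spec.shape)  # torch.Size([2, 3, 8, 6, 2])

# Old (removed): torch.irfft(spec, signal_ndim=2, signal_sizes=x.shape[2:], normalized=True)
# New equivalent: irfft2 takes a complex tensor and returns a *real* one.
recon = torch.fft.irfft2(torch.view_as_complex(spec), s=x.shape[2:], norm='ortho')
print(recon.dtype, torch.allclose(recon, x, atol=1e-5))  # torch.float32 True

# By contrast, fft2/ifft2 are full complex transforms, so the result stays complex.
full = torch.fft.ifft2(torch.fft.fft2(x, norm='ortho'), norm='ortho')
print(full.dtype)  # torch.complex64
```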
I tried to reimplement it with the newer functions. Here is my code:
```python
import torch
import torch.nn as nn


class FourierUnit(nn.Module):
    def __init__(self, in_channels, out_channels, groups=1):
        super(FourierUnit, self).__init__()
        self.groups = groups
        # 1x1 conv applied to the stacked real/imaginary parts (hence the * 2)
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        batch, c, h, w = x.size()
        r_size = x.size()

        # fft2 is a full two-sided transform: (batch, c, h, w, 2)
        ffted = torch.view_as_real(torch.fft.fft2(x, norm='ortho'))
        # (batch, c, 2, h, w)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()
        # (batch, c*2, h, w)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])
        ffted = self.conv_layer(ffted)  # (batch, out_channels*2, h, w)
        ffted = self.relu(self.bn(ffted))

        # (batch, out_channels, h, w, 2)
        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(0, 1, 3, 4, 2).contiguous()
        ffted = torch.view_as_complex(ffted)

        # ifft2 returns a complex tensor: (batch, out_channels, h, w)
        output = torch.fft.ifft2(ffted, s=r_size[2:], norm='ortho')

        return output
```
But when I use this unit inside another model, the following error is thrown:
Input type (CUDAComplexFloatType) and weight type (torch.cuda.FloatTensor) should be the same
How do I go about dealing with this? Thanks in advance.
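A minimal sketch of one possible fix, assuming the error is raised by the layer that consumes this unit's output: torch.fft.ifft2 always returns a complex tensor, so a following nn.Conv2d with float weights receives complex input. The sketch swaps fft2/ifft2 for rfft2/irfft2, which mirror the old one-sided torch.rfft/torch.irfft; irfft2 returns a real tensor of the original spatial size. The rest of the unit follows the code above. The sizes in the usage lines are arbitrary, and this has not been checked against the paper's reference implementation.

```python
import torch
import torch.nn as nn


class FourierUnit(nn.Module):
    """Sketch of the Fourier unit using torch.fft.rfft2 / irfft2 (PyTorch >= 1.8)."""

    def __init__(self, in_channels, out_channels, groups=1):
        super().__init__()
        self.groups = groups
        # 1x1 conv over the stacked real/imaginary parts (hence the * 2)
        self.conv_layer = nn.Conv2d(in_channels * 2, out_channels * 2, kernel_size=1,
                                    stride=1, padding=0, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels * 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        batch, c, h, w = x.size()

        # One-sided spectrum: complex (batch, c, h, w//2+1) -> real (batch, c, h, w//2+1, 2)
        ffted = torch.view_as_real(torch.fft.rfft2(x, norm='ortho'))
        # (batch, c, 2, h, w//2+1) -> (batch, c*2, h, w//2+1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()
        ffted = ffted.view(batch, -1, h, w // 2 + 1)

        # Real-valued conv + BN + ReLU in the frequency domain: (batch, out*2, h, w//2+1)
        ffted = self.relu(self.bn(self.conv_layer(ffted)))

        # (batch, out, 2, h, w//2+1) -> (batch, out, h, w//2+1, 2), then back to complex
        ffted = ffted.view(batch, -1, 2, h, w // 2 + 1).permute(0, 1, 3, 4, 2).contiguous()
        ffted = torch.view_as_complex(ffted)

        # irfft2 returns a *real* tensor of the original spatial size (h, w),
        # so downstream Conv2d layers receive float input again.
        return torch.fft.irfft2(ffted, s=(h, w), norm='ortho')


# Usage with arbitrary sizes
unit = FourierUnit(in_channels=16, out_channels=16)
x = torch.randn(2, 16, 32, 32)
y = unit(x)
print(y.shape, y.dtype)  # torch.Size([2, 16, 32, 32]) torch.float32
```

Taking output.real after ifft2 would also silence the error, but rfft2/irfft2 follow the old onesided rfft/irfft more closely and keep the frequency-domain width at w//2+1. Passing s=(h, w) to irfft2 matters for odd widths, where the original width cannot be inferred from w//2+1.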