Problem with complex tensors

Hi, I’m trying to implement a Fourier unit as part of an object detection model. The original code, written by the paper’s author, targets an old version of PyTorch and uses the torch.rfft() and torch.irfft() methods, which have been replaced by the functions in the torch.fft module in newer versions.
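For reference, here is a minimal sketch of how the old onesided real FFT maps onto the torch.fft module, as I understand it. torch.rfft(x, signal_ndim=2, normalized=True, onesided=True) corresponds roughly to torch.view_as_real(torch.fft.rfft2(x, norm='ortho')), and the matching torch.irfft call corresponds to torch.fft.irfft2 applied to torch.view_as_complex(...) of its input. The (2, 3, 8, 8) input below is arbitrary.

```python
import torch

# Arbitrary example input: batch of 2, 3 channels, 8x8 spatial grid.
x = torch.randn(2, 3, 8, 8)

# Real-to-complex 2-D FFT over the last two dims. Only the non-redundant
# half of the last dim is kept, so the width becomes 8 // 2 + 1 = 5.
X = torch.fft.rfft2(x, norm="ortho")

# Old-style layout: a real tensor with a trailing (real, imag) dim of size 2.
X_ri = torch.view_as_real(X)

# Inverse: go back to a complex tensor, then pass the original spatial size
# via s= so that odd widths also round-trip correctly.
x_back = torch.fft.irfft2(torch.view_as_complex(X_ri), s=x.shape[-2:], norm="ortho")

print(X.shape, X.dtype)
print(X_ri.shape)
print(x_back.dtype, torch.allclose(x, x_back, atol=1e-5))
```

Note that rfft2 keeps only w // 2 + 1 frequency columns, which is the w/2+1 width that the comments in the code below refer to. torch.fft.fft2 returns all w columns.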

I tried reimplementing it with the newer functions. The code is as follows:

import torch
import torch.nn as nn


class FourierUnit(nn.Module):
    def __init__(self, in_channels, out_channels, groups=1):
        super(FourierUnit, self).__init__()
        self.groups = groups
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        batch, c, h, w = x.size()
        r_size = x.size()
        # (batch, c, h, w/2+1, 2)
        ffted = torch.view_as_real(torch.fft.fft2(x, norm='ortho'))
        # (batch, c, 2, h, w/2+1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(0, 1, 3, 4, 2).contiguous()  # (batch, c, h, w/2+1, 2)
        output = torch.fft.ifft2(ffted, s=r_size[2:], norm='ortho')

        return output

But when I use this unit inside another model, the following error is thrown:
Input type (CUDAComplexFloatType) and weight type (torch.cuda.FloatTensor) should be the same

How do I go about dealing with this? Thanks in advance.

You’re trying to pass a complex-valued input into a layer with float weights, and you can’t mix dtypes like that. This is probably occurring in this layer, self.conv_layer.

So, you need to find a way to define conv_layer with complex dtype support.
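You can reproduce the mismatch the reply describes in isolation with a toy layer. A Conv2d has float32 weights by default, so feeding it a complex64 tensor raises a RuntimeError. The channel count and spatial size here are arbitrary. The exact message wording differs between CPU and CUDA and across PyTorch versions, so the sketch only checks that an error is raised.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(4, 4, kernel_size=1, bias=False)  # weights default to float32
z = torch.randn(1, 4, 8, 8, dtype=torch.cfloat)    # complex64 input

try:
    conv(z)
    raised = False
except RuntimeError as err:
    # Expect an "Input type ... and weight type ... should be the same"-style message.
    raised = True
    print(err)
```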

Thanks for the answer. Since I already call view_as_real() before passing the output to the conv layer, won’t it just be dealing with a 5-dimensional real input?
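One way to check this reasoning is to trace dtypes and shapes through the same steps as FourierUnit.forward above, using a small random input. BN and ReLU are left out because they do not change dtype. The trace suggests that conv_layer does receive a real 4-D tensor, and that the complex tensor only appears at the end, from torch.fft.ifft2.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, c, h, w = 2, 12, 8, 8
x = torch.randn(batch, c, h, w)

# Same steps as FourierUnit.forward in the question, with BN and ReLU left out.
ffted = torch.view_as_real(torch.fft.fft2(x, norm="ortho"))  # real, 5-D: (2, 12, 8, 8, 2)
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()             # (2, 12, 2, 8, 8)
conv_in = ffted.view((batch, -1) + ffted.size()[3:])          # real, 4-D: (2, 24, 8, 8)

conv = nn.Conv2d(c * 2, c * 2, kernel_size=1, bias=False)
ffted = conv(conv_in)  # float weights and float input: no dtype error here

ffted = ffted.view((batch, -1, 2) + ffted.size()[2:]).permute(0, 1, 3, 4, 2).contiguous()
out = torch.fft.ifft2(ffted, s=(h, w), norm="ortho")

# ifft2 always returns a complex tensor. The trailing size-2 dim is also still
# present, so the transform ran over the last two dims (w and the 2), not (h, w).
print(conv_in.dtype, tuple(conv_in.shape))
print(out.dtype, tuple(out.shape))
```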

Can you post the full error? Surely it tells you which line the error occurs on?

I added the Fourier Unit to another network.

import torch.nn as nn
import torch.nn.functional as F


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2,2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(24)
        self.fc1 = nn.Linear(24*10*10, 10)

        self.FU = FourierUnit(12,12)

    def forward(self, input):
        output = F.relu(self.bn1(self.conv1(input)))
        output = self.FU(output)   
        output = F.relu(self.bn2(self.conv2(output)))     
        output = self.pool(output)                        
        output = F.relu(self.bn4(self.conv4(output)))     
        output = F.relu(self.bn5(self.conv5(output)))     
        output = output.view(-1, 24*10*10)
        output = self.fc1(output)

        return output
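fc1 expects 24*10*10 input features, which only works out for one particular input size. Assume 32x32 inputs (for example CIFAR-10 images; the thread does not state the input size) and a Fourier unit that leaves the spatial size unchanged. Under those assumptions, the usual output-size formula floor((n + 2p - k) / s) + 1 chains through the layers as follows. The helper conv_out is hypothetical and written only for this check.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                   # assumed input height/width
size = conv_out(size, kernel=5, padding=1)  # conv1 -> 30
# The Fourier unit is assumed to keep the spatial size unchanged here.
size = conv_out(size, kernel=5, padding=1)  # conv2 -> 28
size = conv_out(size, kernel=2, stride=2)   # pool  -> 14
size = conv_out(size, kernel=5, padding=1)  # conv4 -> 12
size = conv_out(size, kernel=5, padding=1)  # conv5 -> 10
print(24 * size * size)                     # features fed to fc1
```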

The error is thrown when I try to train the network.

RuntimeError                              Traceback (most recent call last)
<ipython-input-12-859a81d32bfb> in <module>()
      3     # Let's build our model
----> 4     train(5)
      5     print('Finished Training')

5 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/ in _conv_forward(self, input, weight, bias)
    452                             _pair(0), self.dilation, self.groups)
    453         return F.conv2d(input, weight, bias, self.stride,
--> 454                         self.padding, self.dilation, self.groups)
    456     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Input type (CUDAComplexFloatType) and weight type (torch.cuda.FloatTensor) should be the same
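Going by the shapes, conv_layer itself does receive a real tensor. However, torch.fft.ifft2 always returns a complex tensor, so the unit hands a complex tensor to the next layer in Network (conv2). That is consistent with an error raised from _conv_forward. Below is a minimal sketch of one way around it, assuming the goal is to reproduce the old rfft/irfft behaviour:

1. Take the onesided torch.fft.rfft2.
2. Rebuild a complex tensor with torch.view_as_complex after the convolution.
3. Invert with torch.fft.irfft2, which returns a real tensor of size (h, w).

This is a reconstruction, not the paper author's code.

```python
import torch
import torch.nn as nn


class FourierUnit(nn.Module):
    """Sketch: real FFT -> 1x1 conv on stacked (real, imag) channels -> inverse real FFT."""

    def __init__(self, in_channels, out_channels, groups=1):
        super().__init__()
        self.groups = groups
        self.conv_layer = nn.Conv2d(in_channels * 2, out_channels * 2, kernel_size=1,
                                    stride=1, padding=0, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels * 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        batch, c, h, w = x.size()
        # (batch, c, h, w//2+1) complex -> (batch, c, h, w//2+1, 2) real
        ffted = torch.view_as_real(torch.fft.rfft2(x, norm="ortho"))
        # (batch, c, 2, h, w//2+1) -> (batch, 2c, h, w//2+1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()
        ffted = ffted.view((batch, -1) + ffted.size()[3:])

        ffted = self.relu(self.bn(self.conv_layer(ffted)))  # real in, real out

        # (batch, 2*out, h, w//2+1) -> (batch, out, h, w//2+1, 2)
        ffted = ffted.view((batch, -1, 2) + ffted.size()[2:]).permute(0, 1, 3, 4, 2).contiguous()
        # Rebuild the complex tensor, then invert with the real-output inverse FFT.
        ffted = torch.view_as_complex(ffted)
        return torch.fft.irfft2(ffted, s=(h, w), norm="ortho")
```

With this class in place of the original, self.FU = FourierUnit(12, 12) in Network should pass a real (batch, 12, h, w) tensor on to conv2.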