RuntimeError: For complex Tensors, both grad_output and output are required to have the same dtype. Mismatch in dtype: grad_output[0] has a dtype of torch.complex64 and output[0] has a dtype of torch.float32

I want to build a neural network that works with autograd and uses complex values for its input and output. For the latter, I passed dtype=torch.cfloat to the layers of the network. However, I’m struggling to fix the following error, which I get when running my code. Here is a snippet:

import numpy as np
import torch
import torch.optim as optim
from torch.autograd import grad

class mySin(torch.nn.Module):
    # forward() must take self as its first parameter
    def forward(self, input):
        return torch.sin(input)

class Net(torch.nn.Module):
    def __init__(self, D_hid=10):
        super().__init__()  # initialize nn.Module before assigning submodules
        self.actF = mySin()
        self.Ein    = torch.nn.Linear(1,1, dtype=torch.cfloat)
        self.Lin_1  = torch.nn.Linear(2, D_hid, dtype=torch.cfloat)
        self.Lin_2  = torch.nn.Linear(D_hid, D_hid, dtype=torch.cfloat)
        self.out    = torch.nn.Linear(D_hid, 1, dtype=torch.cfloat)

    def forward(self,t):
        In1 = self.Ein(torch.ones_like(t, dtype=torch.cfloat))
        L1 = self.Lin_1(torch.cat((t, In1), 1))  # join t and In1 into the 2 features Lin_1 expects
        h1 = self.actF(L1)
        L2 = self.Lin_2(h1)
        h2 = self.actF(L2)
        out = self.out(h2)
        return out, In1

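def dfx(x, f):
    # Calculate the derivative of f with respect to x using automatic differentiation.
    # `dtype` was not defined in the snippet; torch.cfloat is assumed here, which
    # matches the complex64 grad_output reported in the error message.
    return grad([f], [x], grad_outputs=torch.ones(x.shape, dtype=torch.cfloat), create_graph=True)[0]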

# differential equation residual 
def diffeq_residual(t,psi, E):
    psi_dx = dfx(t,psi)
    psi_ddx= dfx(t,psi_dx)
    f  = (psi_ddx)/2 + E*(psi)
    L  = f.abs().pow(2).mean()  # |f|^2 keeps the loss a real scalar, which .backward() requires
    return L 

# testing the neural network
net = Net()
input = torch.rand(4).reshape(-1,1)
input.requires_grad = True
nn, En = net(input)
Loss = diffeq_residual(input, nn, En)


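I have a feeling the issue here is that your input is initialized in float32 (PyTorch’s default dtype; if you want to change it across the entire script, you need to use torch.set_default_dtype(torch.cfloat)). Because t is real, the first call to dfx in diffeq_residual returns psi_dx with the same dtype as t (float32), so the second call pairs a float32 output with complex grad_outputs, which is exactly the mismatch in the error.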
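A minimal sketch that reproduces the mismatch outside the network, assuming a recent PyTorch release. The function psi = w * t**2 and the constant w are hypothetical stand-ins for the network output, not the original model; the point is only the dtype that autograd hands back for a real input.

```python
import torch

# Hypothetical stand-in for the network: complex output psi of a *real* input t.
t = torch.rand(4, 1, requires_grad=True)   # float32, like torch.rand in the question
w = torch.tensor(1.0 + 2.0j)               # complex64 constant (illustrative only)
psi = w * t ** 2                           # complex64 output

# First derivative: complex grad_outputs match the complex psi, so this call works,
# but the gradient comes back with the dtype of t, i.e. float32.
psi_dx = torch.autograd.grad(psi, t,
                             grad_outputs=torch.ones(t.shape, dtype=torch.cfloat),
                             create_graph=True)[0]
print(psi_dx.dtype)  # torch.float32

# Second derivative: psi_dx is float32 but grad_outputs is still complex64,
# which raises the "For complex Tensors ..." RuntimeError from the question.
raised = False
message = ""
try:
    torch.autograd.grad(psi_dx, t,
                        grad_outputs=torch.ones(t.shape, dtype=torch.cfloat),
                        create_graph=True)
except RuntimeError as err:
    raised = True
    message = str(err)
print(message)
```

Making t complex (as suggested next) keeps every intermediate derivative complex, so the dtypes of outputs and grad_outputs stay consistent.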

What you could try is casting your input to cfloat before passing it through the network:

input = torch.rand(4).reshape(-1,1)
input = torch.complex(input, torch.zeros_like(input))
input.requires_grad = True
nn, En = net(input)
Loss = diffeq_residual(input, nn, En)

and see if that works!
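If you would rather keep t real (it is a physical coordinate here), an alternative that is not from this thread is to differentiate the real and imaginary parts of psi separately and recombine them. This is a sketch under that assumption; d_dt is a hypothetical helper, not a PyTorch API. One reason to consider it: as I understand PyTorch's "Autograd for Complex Numbers" notes, autograd follows the conjugate Wirtinger convention, so with a complex t and a holomorphic network the values returned with grad_outputs of ones can be the complex conjugate of dψ/dt rather than dψ/dt itself. Splitting into real parts avoids that question entirely.

```python
import torch

def d_dt(t, f):
    """Hypothetical helper: derivative of f with respect to a real tensor t.

    Real f: a single autograd call. Complex f: differentiate the real and
    imaginary parts separately (both real functions of t) and recombine.
    Assumes each part actually depends on t through the graph.
    """
    if f.is_complex():
        return torch.complex(d_dt(t, f.real), d_dt(t, f.imag))
    return torch.autograd.grad(f, t, grad_outputs=torch.ones_like(f),
                               create_graph=True)[0]

# Check on a function with a known derivative: psi = exp(i t).
t = torch.linspace(0.0, 1.0, 5).reshape(-1, 1).requires_grad_(True)
psi = torch.exp(1j * t)        # complex64; exact d/dt is i * psi
psi_dt = d_dt(t, psi)          # first derivative
psi_dtt = d_dt(t, psi_dt)      # second derivative; exact value is -psi
```

With this helper, diffeq_residual could call d_dt(t, psi) and d_dt(t, psi_dx) in place of dfx while the input stays float32.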

Thank you very much. Your answer helped me resolve the issue.

I am using Google Colab. When I try torch.set_default_dtype(torch.cfloat), I get the following: “Your session crashed for an unknown reason.” This happens for both torch.cfloat and torch.complex64, but the code runs fine for torch.float32 and torch.float64. What could be the issue here?

It seems that torch.set_default_dtype only accepts real floating-point dtypes, which I guess makes sense, as most of the underlying mathematics is done with real numbers.

According to the docs, the default dtype is also used to infer the default complex dtype (torch.float32 gives torch.complex64, torch.float64 gives torch.complex128), so I’d assume it has to be a real floating-point dtype, e.g. torch.float32 or torch.float64.
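A short sketch of that inference, following the example in the torch.set_default_dtype docs: you set a real default, and complex tensors created without an explicit dtype pick the matching complex width. To get complex layers or tensors you would still pass dtype=torch.cfloat explicitly, as the question does.

```python
import torch

torch.set_default_dtype(torch.float64)          # a real floating-point default
inferred_complex_64 = torch.tensor(1.0 + 2.0j).dtype
print(inferred_complex_64)                      # torch.complex128

torch.set_default_dtype(torch.float32)          # restore PyTorch's usual default
inferred_complex_32 = torch.tensor(1.0 + 2.0j).dtype
print(inferred_complex_32)                      # torch.complex64
```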