Error when calling backward() on a complex tensor

Hi, I get the following error when I call backward() on a complex tensor. Is there a way to solve it?


RuntimeError: Expected isFloatingType(grad.scalar_type()) || (input_is_complex == grad_is_complex) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)

Below is a minimal example that reproduces the error. I feed complex inputs to my network by splitting the real and imaginary parts into separate channels. Once I get the output, I recombine the two channels into a complex tensor, compute the loss from it, and call backward() on that loss.
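
For reference, the channel split works roughly like the sketch below. The helper names `to_channels` and `from_channels` are hypothetical, and the layout (channel 0 = real part, channel 1 = imaginary part, shape (2, N)) is an assumption matching the `y[0, :]` / `y[1, :]` indexing in the example. `torch.complex(real, imag)` is the built-in way to rebuild the complex tensor in recent PyTorch releases.

```python
import torch

def to_channels(z: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: split a complex tensor of shape (N,) into a real
    # tensor of shape (2, N); row 0 is the real part, row 1 the imaginary part.
    return torch.stack([z.real, z.imag], dim=0)

def from_channels(y: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: rebuild the complex tensor from the two real rows.
    return torch.complex(y[0], y[1])

z = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.complex64)
y = to_channels(z)
print(y.shape)  # torch.Size([2, 2])
w = from_channels(y)
```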


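> # Row 0 of each output holds the real channel, row 1 the imaginary channel
> y1 = net(x1)
> y2 = net(x2)
> complex_y1 = y1[0, :] + 1j * y1[1, :]
> complex_y2 = y2[0, :] + 1j * y2[1, :]
> # Squared magnitude of the difference (complex dtype, zero imaginary part)
> loss = (complex_y1 - complex_y2) * torch.conj(complex_y1 - complex_y2)
> loss.backward(torch.ones(loss.shape, dtype=torch.complex64).cuda())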
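
One possible workaround, assuming the intended loss is just the per-element squared magnitude |complex_y1 - complex_y2|^2: since |d|^2 = Re(d)^2 + Im(d)^2, the same value can be computed directly on the real channels, so the graph stays real-valued and no complex gradient ever has to flow back into a real tensor. In the sketch below, `net` is a hypothetical `torch.nn.Linear` stand-in and the inputs are random, CPU-only placeholders; the final check only confirms that the forward value matches the complex formulation.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the real network: maps a (2, 8) input, whose
# rows are the real and imaginary channels, to a (2, 8) output.
net = torch.nn.Linear(8, 8)

x1 = torch.randn(2, 8)
x2 = torch.randn(2, 8)

y1 = net(x1)
y2 = net(x2)

# |d|^2 = Re(d)^2 + Im(d)^2 with d = complex_y1 - complex_y2, computed on
# the real channels so the whole graph stays real-valued.
loss = ((y1 - y2) ** 2).sum(dim=0)  # shape (8,), like the original loss

# Same effect as loss.backward(torch.ones_like(loss)).
loss.sum().backward()

# Sanity check (forward only): matches the complex formulation.
with torch.no_grad():
    d = torch.complex(y1[0], y1[1]) - torch.complex(y2[0], y2[1])
    matches = torch.allclose(loss, (d * d.conj()).real)
print(matches)  # True
```

Because `loss` is real, `loss.sum().backward()` needs no explicit gradient argument.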