Autograd error on complex tensor, PyTorch 1.6

Hi, I was excited to see PyTorch announcing support for complex numbers, including the statement that

PyTorch supports autograd for complex tensors. The autograd APIs can be used for both holomorphic and non-holomorphic functions.

But then, to test this out, I tried something that should be super simple:

import torch

class ModuleWithComplexLayer(torch.nn.Module):
  def __init__(self):
    super(ModuleWithComplexLayer, self).__init__()
    self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=5,
        kernel_size=(4, 4), stride=(2, 2))
    self.dense_real = torch.nn.Linear(5*9*9, 5*9*9)
    self.dense_imag = torch.nn.Linear(5*9*9, 5*9*9)
    self.deconv1 = torch.nn.ConvTranspose2d(in_channels=5, out_channels=3,
        kernel_size=(4, 4), stride=(2, 2))

  def forward(self, x):
    x = self.conv1(x)
    x = torch.abs(self.dense_real(x.view((-1, 5*9*9))) +
                  1j*self.dense_imag(x.view((-1, 5*9*9))))  # <-- doesn't work
    # x = self.dense_real(x.view((-1, 5*9*9)))  # <-- works
    x = x.view((-1, 5, 9, 9))
    return self.deconv1(x)

torch_device = torch.device('cuda:0')

dummy_data = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(
  torch.randn([10, 3, 20, 20]).to(torch_device)), batch_size=5)

# Instantiate the model
model = ModuleWithComplexLayer().to(torch_device)

# Loss function
loss_object = torch.nn.MSELoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

EPOCHS = 5
for epoch in range(EPOCHS):
  for images in dummy_data:
    optimizer.zero_grad()
    outputs = model(images[0])
    loss = loss_object(outputs, images[0])
    loss.backward()
    optimizer.step()
  print('Epoch {}, Loss: {}'.format(epoch+1, loss.item()))

But this throws the following autograd error:

Traceback (most recent call last):
  File "complex_number_error_example.py", line 43, in <module>
    loss.backward()
  File "/home/me/miniconda3/envs/PyTorch1.6+SciComputing/lib/python3.6/site-packages/torch/tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/me/miniconda3/envs/PyTorch1.6+SciComputing/lib/python3.6/site-packages/torch/autograd/__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: "sign_cuda" not implemented for 'ComplexFloat'
Exception raised from operator() at /opt/conda/conda-bld/pytorch_1595629427286/work/aten/src/ATen/native/cuda/UnarySignKernels.cu:44 (most recent call first):

... and then some more scary-looking frames from .so files...

What am I missing here? It would appear that autograd does not support complex numbers.

First the happy news: if you use the nightlies, or the upcoming 1.7 release, the backward of abs is implemented and works.
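
For example, on a nightly or 1.7-era build, a minimal check like this should run without the error (the expected gradient value is my reading of PyTorch's convention for real-valued functions of complex inputs, namely z / |z|):

import torch  # assumes a nightly or >= 1.7 build

z = torch.tensor([3.0 + 4.0j], requires_grad=True)
z.abs().backward()
# Expected under PyTorch's gradient convention for C -> R functions:
# z.grad == z / |z| = tensor([0.6000+0.8000j])
print(z.grad)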

If you’re into the details of what happens: abs doesn’t support complex numbers in its backward in 1.6. It uses the formula d|x|/dx = sign(x), which is, of course, bogus for complex numbers. While running the backward with this wrong formula, it tries to take the sign of a complex number, and that isn’t a thing, so you get the error you are seeing.
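
If you need a stopgap on 1.6, one option is to route abs through a custom autograd.Function that uses the complex-aware gradient z / |z| (what torch.sgn computes in 1.7). A minimal sketch; the name ComplexAbs and the epsilon guard against |z| = 0 are mine, not a PyTorch API:

import torch

class ComplexAbs(torch.autograd.Function):
  @staticmethod
  def forward(ctx, z):
    ctx.save_for_backward(z)
    return z.abs()

  @staticmethod
  def backward(ctx, grad_output):
    z, = ctx.saved_tensors
    # Complex-aware gradient of |z|: z / |z| (torch.sgn in 1.7).
    # clamp_min avoids dividing by zero at z = 0 (my choice).
    return grad_output * z / z.abs().clamp_min(1e-12)

In the model above, torch.abs(...) in the forward pass would then be replaced by ComplexAbs.apply(...).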

Best regards

Thomas

If this is the case for abs, how can the PyTorch developers possibly claim support for complex numbers in version 1.6? My suspicion is that I’m not the first person to spend significant time trying to migrate code from TensorFlow to PyTorch on the basis of this false claim. That page should be taken down, or significant disclaimers added.

While I can feel your frustration, and agree that the big fat red warning might better call this an experimental rather than a beta feature and suggest more strongly that there are bugs, there is at least a note that this isn’t ready for production.

The unfortunate truth is that new code is susceptible to bugs, and there are many functions whose backwards have to be checked for necessary adjustments. Because it can happen that a backward is implemented in a way that does not work for complex numbers, there has been an effort to disable unchecked backwards on complex inputs by default and re-enable them after checking, so you should at least get more pertinent error messages.
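
As a sketch of what such a check can look like from the user side, torch.autograd.gradcheck can be run against complex double-precision inputs (complex support in gradcheck is itself recent, so this assumes a 1.7-era build; torch.abs is just the example op):

import torch

# Numerically verify an op's backward on complex inputs.
# complex128 is recommended: gradcheck compares against finite differences.
z = torch.randn(4, dtype=torch.complex128, requires_grad=True)
try:
  torch.autograd.gradcheck(torch.abs, (z,))
  print('abs: complex backward checks out')
except RuntimeError as err:
  print('abs: complex backward rejected:', err)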

If it’s any comfort, I might be able to get you a discount on the next PyTorch version, PyTorch 1.7, to be released this month.

Best regards

Thomas

I met this problem too, and I’m sad to hear that. Hope v1.7 can be released soon.