Autograd does not work for torch.stack on complex tensors

It is pretty easy to use unsqueeze and cat as workarounds, but stack does not work and throws the following error:

RuntimeError: Expected isFloatingType(grad.scalar_type()) || (input_is_complex == grad_is_complex) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

The minimal code to reproduce the error:

import torch

params = torch.rand(1, requires_grad=True)   # real float32 leaf tensor
out = torch.stack([params, 1j*params], 1)    # output is promoted to cfloat in the forward
loss = out.sum()
print(loss)
loss.backward()                              # raises the RuntimeError above

However, the following two variants run fine. First, initializing the parameters as cfloat works (of course this changes the behavior and is not a valid solution):

params = torch.rand(1, dtype=torch.cfloat, requires_grad=True)

Second, keeping float parameters but using cat instead of stack also works:

out = torch.cat([params.unsqueeze(1), 1j*params.unsqueeze(1)], 1)
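
For completeness, here is the cat workaround assembled into a full script (my own combination of the snippets above; I reduce with .abs() so the loss is a real scalar, which lets backward() create the gradient implicitly on recent PyTorch versions):

import torch

params = torch.rand(1, requires_grad=True)
# cat promotes to cfloat in the forward, and its backward handles the float/complex mix
out = torch.cat([params.unsqueeze(1), 1j*params.unsqueeze(1)], 1)
loss = out.abs().sum()  # real scalar loss, so no explicit gradient argument is needed
loss.backward()
print(params.grad)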

I mean… it’s normal.
You are stacking tensors of different dtypes.
In the latter example you are concatenating two complex tensors. In the former you are stacking a complex tensor with a float one. It doesn’t work with cat either if you concatenate a float and a complex number, and I would bet it doesn’t work in general for two different data types.

Just cast them to the proper type before stacking and that’s all. Lazy ppl.
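
For instance, something along these lines (a minimal sketch; the explicit .to(torch.cfloat) cast and the real-valued .abs() reduction are my choices here, not the only way to do it):

import torch

params = torch.rand(1, requires_grad=True)
# cast the real tensor up front so both inputs to torch.stack are already cfloat
out = torch.stack([params.to(torch.cfloat), 1j*params], 1)
loss = out.abs().sum()  # real scalar loss, so backward() needs no gradient argument
loss.backward()
print(params.grad)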

If you want PyTorch to add this behavior (i.e. autocasting floats to cfloat when stacking real and complex dtypes), you can open an issue on the PyTorch GitHub and see if they’ll add it.

Hi,

thank you for sending this! It seems like a perfectly good bug report with reproducing code.
This is a shortcoming of the backward of torch.stack. If you ping me (t-vi on GitHub) on an issue, I’ll send a PR.

Best regards

Thomas

I did not realize casting was the issue, but now I see it. The problem is that it DID work with cat (what I meant was that it works when I initialize params to cfloat, or when I initialize them to float but use cat), which really confused me. I spent more time writing this bug report than coming up with my workaround, so I’m not sure why it’s lazy.

Hi Thomas! Thanks for offering to take a look. I have created an issue, torch.stack backward does not automatically cast real float to complex float · Issue #75852 · pytorch/pytorch · GitHub, and tagged you in the comment.