Using unsqueeze and cat as a workaround is easy, but stack does not work and throws the following error:
RuntimeError: Expected isFloatingType(grad.scalar_type()) || (input_is_complex == grad_is_complex) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
The minimal code to reproduce the error:
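import torch
params = torch.rand(1, requires_grad=True)
out = torch.stack([params, 1j*params], 1)
loss = out.abs().sum()  # real-valued loss, so backward() needs no explicit gradient
loss.backward()  # on affected PyTorch versions, the error above is raised here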
However, the following two variants run fine. First, initializing the parameters as cfloat works (of course this changes the behavior, so it is not a valid solution):
params = torch.rand(1, dtype=torch.cfloat, requires_grad=True)
Second, if we don't use stack, it's fine:
out = torch.cat([params.unsqueeze(1), 1j*params.unsqueeze(1)], 1)
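For reference, here is the cat-based workaround as a complete script. It is a sketch that assumes a PyTorch version where torch.cat promotes mixed float/complex inputs; the abs().sum() loss and the fixed seed are choices made here only so that backward() runs on a real scalar and the result is reproducible.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is reproducible
params = torch.rand(1, requires_grad=True)  # real (float32) parameters

# Workaround: unsqueeze each tensor and concatenate along the new dim.
# cat promotes the float32 tensor to complex64 in the forward pass.
out = torch.cat([params.unsqueeze(1), 1j * params.unsqueeze(1)], 1)

loss = out.abs().sum()  # real scalar, so backward() needs no explicit gradient
loss.backward()

print(out.dtype)    # complex64
print(params.grad)  # d/dp (|p| + |1j*p|), which is 2 for p > 0
```

The gradient stays float32 because the backward of cat hands a real input only the real part of its complex gradient.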
I mean… it’s normal.
You are stacking tensors of different types.
In the latter example you are concatenating two complex tensors. In the former you are stacking a complex tensor with a float one. It doesn't work with cat either if you concatenate a float and a complex tensor, and I would bet it doesn't work in general for two different data types.
Just cast them to the proper type before stacking, and that's all. Lazy people.
If you want PyTorch to add this behavior (i.e., automatically casting floats to cfloats when stacking real and complex tensors), you can open an issue on the PyTorch GitHub and see whether they add it.
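Casting to a common dtype before stacking can also be done generically. The sketch below uses torch.promote_types to find the dtype that mixed inputs promote to (complex64 for float32 plus complex64). stack_promoted is a hypothetical helper written for illustration, not a PyTorch API.

```python
import functools

import torch


def stack_promoted(tensors, dim=0):
    """Hypothetical helper: cast all inputs to their common dtype, then stack.

    Because every tensor handed to torch.stack already has the same dtype,
    the real-to-complex conversion happens in .to(), whose backward takes
    care of returning a real gradient to the real input.
    """
    common = functools.reduce(torch.promote_types, (t.dtype for t in tensors))
    return torch.stack([t.to(common) for t in tensors], dim)


torch.manual_seed(0)  # fixed seed so the run is reproducible
params = torch.rand(1, requires_grad=True)
out = stack_promoted([params, 1j * params], dim=1)
loss = out.abs().sum()  # real scalar loss
loss.backward()
print(out.dtype, params.grad)
```

The explicit cast keeps stack itself single-dtype, which is the "cast them to the proper type" advice applied automatically.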
Thank you for sending this! It seems like a perfectly good bug report with reproducing code.
This is a shortcoming of the backward of torch.stack. If you ping me (t-vi on GitHub) on an issue, I'll send a PR.
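To make the "shortcoming of the backward" concrete: the gradient flowing back into stack is complex, and a real input should receive only its real part. The torch.autograd.Function below is a hypothetical illustration of that rule (StackWithRealGrad is a name made up here); it is not the fix that was actually merged into PyTorch.

```python
import torch


class StackWithRealGrad(torch.autograd.Function):
    """Hypothetical stack whose backward casts each gradient to its input's dtype."""

    @staticmethod
    def forward(ctx, dim, *tensors):
        ctx.dim = dim
        # Remember each input's dtype and whether it is complex.
        ctx.input_info = [(t.dtype, t.is_complex()) for t in tensors]
        # torch.stack promotes mixed real/complex inputs to a common dtype.
        return torch.stack(tensors, dim)

    @staticmethod
    def backward(ctx, grad_out):
        grads = []
        for g, (dtype, input_is_complex) in zip(grad_out.unbind(ctx.dim), ctx.input_info):
            # A real input only receives the real part of a complex gradient.
            if g.is_complex() and not input_is_complex:
                g = g.real
            grads.append(g.to(dtype))
        # No gradient for the non-tensor `dim` argument.
        return (None, *grads)


torch.manual_seed(0)  # fixed seed so the run is reproducible
params = torch.rand(1, requires_grad=True)
out = StackWithRealGrad.apply(1, params, 1j * params)
loss = out.abs().sum()
loss.backward()
print(params.grad)
```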
I did not realize casting was the issue, but now I see it. What confused me is that it DID work with cat: it works when I initialize to cfloat, or when I initialize to float but use cat. I spent more time writing this bug report than coming up with my workaround, so I'm not sure why it's lazy.
Hi Thomas! Thanks for offering to take a look. I have created an issue, "torch.stack backward does not automatically cast real float to complex float" (Issue #75852 in pytorch/pytorch on GitHub), and tagged you in the comment.