Thanks for the information.
I can reproduce this error in PyTorch 1.7.0 and get a different error message in 1.8.0.dev20201022:
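```python
import torch

def reshape_fortran(x, shape):
    # Emulate a column-major (Fortran-order) reshape: reverse the dims,
    # reshape in row-major order with the reversed target shape, then reverse back.
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

x = torch.randn(4, 4, 2, dtype=torch.complex64, requires_grad=True)
shape = (8, 4)
out = reshape_fortran(x, shape)
out.mean().backward()  # raises the RuntimeError below in 1.8.0.dev20201022
```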
> RuntimeError: mean does not support automatic differentiation for outputs with complex dtype.
Note that complex support is not fully implemented yet, so I would recommend verifying the results of your second approach against some reference values.
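One possible source of reference values is NumPy's column-major reshape (`order="F"`). The following is a minimal sketch that assumes NumPy is installed. `check_against_numpy` is a hypothetical helper name, and it only compares the forward pass, not gradients:

```python
import numpy as np
import torch


def reshape_fortran(x, shape):
    # Same permute/reshape/permute trick as above (column-major reshape).
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))


def check_against_numpy(x, shape):
    """Hypothetical helper: compare reshape_fortran with np.reshape(order='F').

    Forward values only; this does not check autograd.
    """
    expected = np.reshape(x.detach().numpy(), shape, order="F")
    actual = reshape_fortran(x, shape).detach().numpy()
    # Elements are only moved, never recomputed, so exact equality is expected.
    return np.array_equal(actual, expected)


real_x = torch.arange(32, dtype=torch.float32).reshape(4, 4, 2)
complex_x = torch.randn(4, 4, 2, dtype=torch.complex64)

print(check_against_numpy(real_x, (8, 4)))     # True
print(check_against_numpy(complex_x, (8, 4)))  # True
```

Because the check works on plain values, it can be run for both real and complex inputs independently of whether autograd supports a given complex op.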