Replicating torch.permute with other operations

Since torch.permute is buggy on the mps device (Permute followed by torch.nn.functional.interpolate gives wrong results on mps backend · Issue #88183 · pytorch/pytorch · GitHub), I am looking for a generic replacement for permute operations. Chained invocations of torch.transpose would do the job, but transpose is buggy on mps too. I was thinking of some combination of .view and transpose, but I haven't come up with a generic “formula” yet.
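One generic “formula” that calls neither permute() nor transpose() is to compute, for every output position, the flat row-major index it reads from in the input, and then gather those elements with torch.take(). The helper below is a minimal sketch under stated assumptions. The name permute_by_gather is hypothetical, the helper produces a contiguous copy rather than a view, and it has not been tested on mps: whether torch.take() is supported there, and whether it avoids the bug from the issue above, are both unverified.

```python
import torch

def permute_by_gather(t: torch.Tensor, dims) -> torch.Tensor:
    # Hypothetical helper (not a PyTorch API): returns a contiguous copy equal
    # to t.permute(dims), built with torch.take() instead of permute()/transpose().
    n = t.dim()
    dims = [d % n for d in dims]  # accept negative axes, as permute() does
    if sorted(dims) != list(range(n)):
        raise ValueError(f"{dims} is not a permutation of range({n})")
    # Row-major (contiguous) strides for t's logical shape.
    strides = [1] * n
    for i in range(n - 2, -1, -1):
        strides[i] = strides[i + 1] * t.shape[i + 1]
    out_shape = [t.shape[d] for d in dims]
    # Output axis k steps along input axis dims[k], so it contributes
    # arange(size) * strides[dims[k]] to the flat input index; broadcasting
    # adds up the contributions of all axes.
    flat_idx = torch.zeros(out_shape, dtype=torch.long, device=t.device)
    for k, d in enumerate(dims):
        shape_k = [1] * n
        shape_k[k] = out_shape[k]
        steps = torch.arange(out_shape[k], device=t.device).view(shape_k)
        flat_idx = flat_idx + steps * strides[d]
    # torch.take() treats t as if it were flattened in row-major order and
    # returns a tensor with the same shape as the index tensor.
    return torch.take(t, flat_idx)
```

For example, permute_by_gather(t, (1, 0, 3, 2)) on a tensor of shape (2, 3, 4, 5) should give the same values as t.permute(1, 0, 3, 2), with shape (3, 2, 5, 4). The trade-off is memory: the int64 index tensor has as many elements as t itself.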

Hi Spiegelball!

You could try einsum():

>>> import torch
>>> torch.__version__
'1.12.0'
>>> t = torch.ones (2, 3, 4, 5)
>>> p = torch.einsum ('ijkl -> jilk', t)
>>> t.shape
torch.Size([2, 3, 4, 5])
>>> p.shape
torch.Size([3, 2, 5, 4])

Having said that, it would not surprise me if einsum() uses permute()
and / or transpose() internally, so you might end up right back in the soup.
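To apply the einsum() idea to an arbitrary permutation, the equation string can be built from the dims tuple. The sketch below uses a hypothetical helper, permute_by_einsum, which generalizes 'ijkl -> jilk' to any number of dimensions (up to 26, since it only uses lowercase subscript letters). Given the caveat above, it is mainly a convenience and may not avoid the mps problem.

```python
import string
import torch

def permute_by_einsum(t: torch.Tensor, dims) -> torch.Tensor:
    # Hypothetical helper: builds an equation such as "abcd->badc" for
    # dims=(1, 0, 3, 2), i.e. output axis k is input axis dims[k].
    n = t.dim()
    if n > 26:
        raise ValueError("this sketch only uses the 26 lowercase subscript letters")
    dims = [d % n for d in dims]  # accept negative axes
    if sorted(dims) != list(range(n)):
        raise ValueError(f"{dims} is not a permutation of range({n})")
    letters = string.ascii_lowercase[:n]
    out = "".join(letters[d] for d in dims)
    return torch.einsum(f"{letters}->{out}", t)
```

For dims=(1, 0, 3, 2) this builds "abcd->badc", which is the same equation as 'ijkl -> jilk' with different letters.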

Best.

K. Frank


Thanks for the input; sadly, using einsum leads to the same problem.
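One way to tell whether a candidate replacement actually avoids the problem is to run the same permute-then-interpolate pipeline on the device and on the CPU and compare the results. The sketch below assumes the CPU result is a trustworthy reference. The helper name matches_cpu_reference, the bilinear mode, the tolerance, and the example shapes are all illustrative choices, not taken from the issue.

```python
import torch
import torch.nn.functional as F

def matches_cpu_reference(permute_fn, x, dims, size, device, atol=1e-5):
    # Hypothetical helper: applies permute_fn(tensor, dims) followed by a
    # bilinear F.interpolate on `device` and on the CPU, then reports
    # whether the two results agree within the given tolerance.
    ref = F.interpolate(permute_fn(x.cpu(), dims), size=size,
                        mode="bilinear", align_corners=False)
    out = F.interpolate(permute_fn(x.to(device), dims), size=size,
                        mode="bilinear", align_corners=False)
    return torch.allclose(out.cpu(), ref, atol=atol)

if __name__ == "__main__":
    # Falls back to the CPU (a trivial comparison) when mps is unavailable.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    x = torch.rand(2, 3, 4, 5)
    ok = matches_cpu_reference(lambda u, d: u.permute(*d), x, (0, 3, 1, 2), (6, 8), device)
    print(f"{device}: permute + interpolate matches CPU: {ok}")
```

Any candidate with the same (tensor, dims) signature, such as an einsum-based version, can be passed as permute_fn in place of the lambda.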