Is there an alternative to using torch.diagonal() that does not alter the ordering of a tensor's dimensions?
For example, I have a 4D tensor A of shape (2, 4, 4, 6). I want B[:, i, :] = A[:, i, i, :] for all i in range(4), so that B has shape (2, 4, 6).
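A slow but unambiguous way to express this definition, useful only as a reference to check faster approaches against, is a Python loop that stacks the slices along dim 1. The random tensor A here is just an illustrative input with the shape from the example:

```python
import torch

A = torch.randn(2, 4, 4, 6)  # illustrative input, shape (2, 4, 4, 6)

# Reference definition: B[:, i, :] = A[:, i, i, :], stacked along dim 1
B_ref = torch.stack([A[:, i, i, :] for i in range(A.shape[1])], dim=1)
print(B_ref.shape)  # torch.Size([2, 4, 6])
```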
I did the following in PyTorch:
import torch
tmp = torch.diagonal(A, dim1=1, dim2=2)  # shape (2, 6, 4): the diagonal dim is appended last
B = tmp.permute(0, 2, 1).contiguous()    # shape (2, 4, 6)
Is there a way to avoid the permute() operation?
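Two options that produce shape (2, 4, 6) directly are sketched below; they are possibilities to try, not a definitive answer. The first uses advanced indexing with the same index tensor on dims 1 and 2: because the two advanced indices are adjacent, the resulting dimension stays at position 1, and the result is a new tensor rather than a view. The second uses torch.einsum with a repeated subscript, which takes the diagonal and places it wherever the output spec says. Both are checked against the diagonal-plus-permute version from the question; the input A is illustrative.

```python
import torch

A = torch.randn(2, 4, 4, 6)  # illustrative input

# Approach from the question: diagonal -> (2, 6, 4), then permute -> (2, 4, 6)
B_perm = torch.diagonal(A, dim1=1, dim2=2).permute(0, 2, 1).contiguous()

# Option 1: advanced indexing with one index tensor on both dims 1 and 2.
# Adjacent advanced indices keep the indexed dim at position 1 -> (2, 4, 6).
idx = torch.arange(A.shape[1])
B_idx = A[:, idx, idx, :]

# Option 2: einsum; the repeated 'i' in 'biij' selects the diagonal of dims 1 and 2,
# and the output spec 'bij' puts it between the batch dim and the last dim.
B_ein = torch.einsum('biij->bij', A)

print(B_idx.shape, B_ein.shape)  # torch.Size([2, 4, 6]) torch.Size([2, 4, 6])
print(torch.equal(B_idx, B_perm), torch.allclose(B_ein, B_perm))  # True True
```

Note that permute() on its own only returns a view and does not copy data; in the original snippet the copy comes from .contiguous(). If a non-contiguous view is acceptable downstream, the permute can simply be kept and .contiguous() dropped.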
Many thanks.