Hi there,
I am trying to permute a tensor of shape [batch_shape, 4, 4] along its third (last) dimension. The idea is to draw a random permutation of the indices (0,1,2,3), e.g. (0,1,3,2), (0,2,1,3), …, (3,1,2,0), independently for each element of the batch, and permute that element with it.
For example, the first element might be permuted as
permuted_tensor[0] = tensor_to_permute[0, :, (3,1,2,0)]
and the second element as:
permuted_tensor[1] = tensor_to_permute[1, :, (1,2,0,3)]
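If the per-element permutations are stacked into one [batch, 4] index tensor, both lines above can be written as a single advanced-indexing operation. A minimal sketch, assuming a small hypothetical batch of two 4x4 elements and the two example permutations; the names perms, batch_idx, row_idx and col_idx are illustrative only:

```python
import torch

# Hypothetical example batch: 2 elements, each 4x4, filled with 0..31 so moved columns are easy to spot.
tensor_to_permute = torch.arange(2 * 4 * 4).reshape(2, 4, 4)

# One permutation per batch element, stacked into a [batch, 4] index tensor.
perms = torch.tensor([[3, 1, 2, 0],
                      [1, 2, 0, 3]])

batch_idx = torch.arange(tensor_to_permute.shape[0])[:, None, None]  # [B, 1, 1]
row_idx = torch.arange(tensor_to_permute.shape[1])[None, :, None]    # [1, 4, 1]
col_idx = perms[:, None, :]                                          # [B, 1, 4]

# The three index tensors broadcast to [B, 4, 4], so
# permuted_tensor[b, r, c] = tensor_to_permute[b, r, perms[b, c]]
permuted_tensor = tensor_to_permute[batch_idx, row_idx, col_idx]
```

Each slice permuted_tensor[i] then matches tensor_to_permute[i, :, perms[i]], with no Python loop over the batch.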
This would be straightforward with a for loop:
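permuted_tensor = torch.empty_like(tensor_to_permute)
for i in range(tensor_to_permute.shape[0]):
    # a fresh random permutation of the last dimension for each batch element
    permuted_tensor[i] = tensor_to_permute[i, :, torch.randperm(4)]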
but I am very reluctant to use it, since it becomes notably slow as the batch grows (one randperm call and one small indexing operation per element). Any advice on how to vectorize it?
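One common loop-free approach (a sketch, not the only option): draw uniform noise of shape [batch, 4] and argsort it along the last dimension, which gives an independent random permutation per row, then apply all of them at once with torch.gather. The helper name random_permute_last_dim is hypothetical, and it also returns the permutations so the result can be checked:

```python
from typing import Tuple

import torch


def random_permute_last_dim(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Permute the last dim of a [B, N, M] tensor with an independent random permutation per batch element."""
    B, N, M = x.shape
    # argsort of i.i.d. uniform noise gives a random permutation of 0..M-1 for each row: [B, M]
    perms = torch.rand(B, M, device=x.device).argsort(dim=-1)
    # Repeat each row's permutation across the N rows of that element: [B, N, M]
    index = perms[:, None, :].expand(B, N, M)
    # gather along dim 2: out[b, r, c] = x[b, r, perms[b, c]]
    return torch.gather(x, 2, index), perms


tensor_to_permute = torch.randn(1000, 4, 4)
permuted_tensor, perms = random_permute_last_dim(tensor_to_permute)
```

The whole batch is handled by one rand, one argsort and one gather call, so the cost no longer involves a Python-level loop over batch elements. The advanced-indexing form with broadcast index tensors would work equally well in place of gather.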
Thanks!