Batched permutation

Hi there,

I am trying to permute a tensor of shape [batch_shape, 4, 4] along its third dimension. The idea is to draw a random permutation of the indices (0,1,2,3), e.g. (0,1,3,2), (0,2,1,3), …, (3,1,2,0), independently for each element of the batch, and permute that element with it.

For example, the first element might be indexed as

permuted_tensor[0] = tensor_to_permute[0, :, (3,1,2,0)]

and the second element as:

permuted_tensor[1] = tensor_to_permute[1, :, (1,2,0,3)]

Now this is one of those things that would be much easier with a for loop:

permuted_tensor = torch.empty_like(tensor_to_permute)
for i in range(tensor_to_permute.shape[0]):
    permuted_tensor[i] = tensor_to_permute[i, :, torch.randperm(4)]

but I'd rather avoid it, since it notably increases the computation time. Any advice on how to do this without a loop?

Thanks!

Hi Javier!

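The problem is that pytorch (still) does not offer a batch version of
randperm(). The standard work-around is to compute argsort()
for a batch of random vectors: sorting independent uniform random
values puts them in a uniformly random order, so each row of the
argsort() result is an independent random permutation.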

Thus:

>>> import torch
>>> _ = torch.manual_seed (2023)
>>>
>>> tensor_to_permute = torch.randint (100, (3, 4, 4))   # some test data
>>> perms = torch.rand (3, 4).argsort (dim = 1)          # a way to get a batch of permutations
>>>
>>> permuted_tensor = torch.empty_like (tensor_to_permute)
>>> for i in range(tensor_to_permute.shape[0]):
...     permuted_tensor[i] = tensor_to_permute[i, :, perms[i]]
...
>>> permuted_tensorB = tensor_to_permute.gather (2, perms.unsqueeze (1).expand (-1, 4, -1))
>>>
>>> torch.equal (permuted_tensor, permuted_tensorB)
True
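Wrapped up as reusable helpers, the same argsort()-plus-gather() approach might look like the sketch below. The function names batched_randperm and permute_last_dim are hypothetical (they are not part of PyTorch), and the sketch assumes a 3-d input of shape [B, M, N] so that the middle dimension need not be 4.

```python
import torch

def batched_randperm(batch_size: int, n: int, device=None) -> torch.Tensor:
    """Hypothetical helper: return a (batch_size, n) tensor whose rows are
    independent random permutations of range(n), via argsort of random keys."""
    return torch.rand(batch_size, n, device=device).argsort(dim=1)

def permute_last_dim(x: torch.Tensor, perms: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: permute the last dim of x (shape [B, M, N])
    with a separate permutation per batch element; perms has shape [B, N]."""
    # Broadcast each row's permutation over the M rows: index is [B, M, N]
    index = perms.unsqueeze(1).expand(-1, x.shape[1], -1)
    # out[b, m, k] = x[b, m, perms[b, k]]
    return x.gather(2, index)

torch.manual_seed(0)
x = torch.randint(100, (5, 4, 4))
perms = batched_randperm(x.shape[0], x.shape[2])
y = permute_last_dim(x, perms)
```

Because the index is built from perms.unsqueeze(1).expand(...), no data is copied for the index itself; expand() only creates a broadcast view.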

(You could use pytorch tensor indexing (“advanced indexing”) instead of
gather(), if you prefer that syntax.)
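For comparison, here is one way the advanced-indexing version could be written. It is a sketch that makes all three index tensors broadcast to [3, 4, 4]; it deliberately avoids the shorter tensor_to_permute[torch.arange(3)[:, None], :, perms], because when advanced indices are separated by a slice, the broadcast index dimensions move to the front and the result comes out transposed.

```python
import torch

torch.manual_seed(2023)
tensor_to_permute = torch.randint(100, (3, 4, 4))
perms = torch.rand(3, 4).argsort(dim=1)

# gather() version, as in the transcript above
by_gather = tensor_to_permute.gather(2, perms.unsqueeze(1).expand(-1, 4, -1))

# Advanced-indexing version: index tensors broadcast to shape [3, 4, 4]
b = torch.arange(3)[:, None, None]   # batch index,  shape [3, 1, 1]
r = torch.arange(4)[None, :, None]   # row index,    shape [1, 4, 1]
c = perms[:, None, :]                # column perms, shape [3, 1, 4]
by_indexing = tensor_to_permute[b, r, c]   # [b, r, k] -> x[b, r, perms[b, k]]
```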

This is mildly inefficient in that the cost of sorting a vector of length n is
O (n log (n)), while the cost of generating a random permutation of
length n is O (n). In your example, n = 4, so that log (n) factor is
irrelevant, but it could matter as n becomes large. However, the slowdown
due to the python loop will be worse than the log (n) factor unless
batch_shape is small and n is quite large.
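The trade-off can be checked empirically. The following is a rough CPU timing sketch comparing the per-element randperm() loop from the question with the batched argsort()/gather() approach. The batch size, n, and repeat count are arbitrary assumptions, and absolute timings will vary by machine, so no numbers are given here.

```python
import time
import torch

def loop_version(x):
    # Per-element loop from the question: one randperm() call per batch element
    out = torch.empty_like(x)
    for i in range(x.shape[0]):
        out[i] = x[i, :, torch.randperm(x.shape[2])]
    return out

def batched_version(x):
    # Batched permutations via argsort of random keys, applied with gather()
    b, m, n = x.shape
    perms = torch.rand(b, n).argsort(dim=1)
    return x.gather(2, perms.unsqueeze(1).expand(-1, m, -1))

def best_time(fn, x, repeats=3):
    # Best wall-clock time over a few runs, to reduce noise
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(x)
        times.append(time.perf_counter() - start)
    return min(times)

batch, n = 2000, 4   # assumed sizes, matching the 4 x 4 example
x = torch.randint(100, (batch, n, n))
print(f"python loop:      {best_time(loop_version, x):.4f} s")
print(f"argsort + gather: {best_time(batched_version, x):.4f} s")
```

On a GPU, you would also need torch.cuda.synchronize() before reading the clock, since CUDA kernels run asynchronously.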

Best.

K. Frank