# Batched permutation

Hi there,

I am trying to permute a tensor of shape [batch_shape, 4, 4] along its third dimension. The idea is that I want to draw a random permutation of the indices (0, 1, 2, 3), e.g. (0,1,3,2), (0,2,1,3), …, (3,1,2,0), and apply a different permutation to each element along the batch.

For example, first element may be indexed such that

```
permuted_tensor = tensor_to_permute[0, :, (3, 1, 2, 0)]
```

and the second element such that:

```
permuted_tensor = tensor_to_permute[1, :, (1, 2, 0, 3)]
```

Now this is one of those things that would be much easier with a for loop:

```
for i in range(tensor_to_permute.shape[0]):
    permuted_tensor[i] = tensor_to_permute[i, :, torch.randperm(4)]
```

but I am very reluctant to use one, since it increases the computation time notably. Any advice on how to do this?

Thanks!

Hi Javier!

The problem is that pytorch (still) does not offer a batched version of
`randperm()`. The standard work-around is to compute `argsort()`
for a batch of random vectors.

Thus:

```
>>> _ = torch.manual_seed (2023)
>>>
>>> tensor_to_permute = torch.randint (100, (3, 4, 4))   # some test data
>>> perms = torch.rand (3, 4).argsort (dim = 1)          # a way to get a batch of permutations
>>>
>>> permuted_tensor = torch.empty_like (tensor_to_permute)
>>> for i in range (tensor_to_permute.shape[0]):
...     permuted_tensor[i] = tensor_to_permute[i, :, perms[i]]
...
>>> permuted_tensorB = tensor_to_permute.gather (2, perms.unsqueeze (1).expand (-1, 4, -1))
>>>
>>> torch.equal (permuted_tensor, permuted_tensorB)
True
```

(Note that `permuted_tensorB` is computed without a python loop, using
`gather()`, if you prefer that syntax.)
This is mildly inefficient in that the cost of sorting a vector of length `n` is
`O (n log (n))`, while the cost of generating a random permutation of
length `n` is `O (n)`. In your example, `n = 4`, so the `log (n)` factor is
irrelevant, but as `n` becomes large, it could matter. However, the slowdown
due to the python loop will be worse than the `log (n)` factor unless `batch_shape`
is small and `n` is quite large.
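For reference, the argsort-plus-gather trick can be wrapped in a small helper. This is just a sketch (the function name `batched_permute_last_dim` is made up for illustration); it assumes a 3-d input and permutes the last dimension with an independent random permutation per batch element:

```python
import torch

def batched_permute_last_dim(t: torch.Tensor) -> torch.Tensor:
    """Apply an independent random permutation of the last dim
    to each batch element of a [batch, rows, n] tensor."""
    batch, rows, n = t.shape
    # argsort of random values gives one uniform random permutation per row
    perms = torch.rand(batch, n).argsort(dim=1)
    # expand the permutation across the row dimension and gather
    return t.gather(2, perms.unsqueeze(1).expand(-1, rows, -1))

t = torch.arange(2 * 3 * 5).reshape(2, 3, 5)
out = batched_permute_last_dim(t)
# the permutation only reorders values along the last dim, so sorting
# the result along dim 2 recovers the original (already-sorted) tensor
assert torch.equal(out.sort(dim=2).values, t)
```

Because `t` was built with `arange`, it is already sorted along its last dimension, which makes the sanity check at the end easy to state.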