How to quickly invert a permutation using PyTorch?

I am unsure how to quickly restore an array that was shuffled by a permutation.

Example #1:
[x, y, z] shuffled by P: [2, 0, 1] gives [z, x, y],
and the corresponding inverse should be P^-1: [1, 2, 0].

Example #2:
[a, b, c, d, e, f] shuffled by P: [5, 2, 0, 1, 4, 3] gives [f, c, a, b, e, d],
and the corresponding inverse should be P^-1: [2, 3, 1, 5, 4, 0].
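To make the indexing convention in these examples explicit, here is a small sketch that replays Example #2. Plain Python strings stand in for the letters, and the names x, P and P_inv are only for illustration. Shuffling means shuffled[i] = x[P[i]], and indexing the shuffled result with P^-1 restores the original order.

```python
import torch

# Example #2 from the question; letters stand in for tensor values
x = ['a', 'b', 'c', 'd', 'e', 'f']
P = torch.tensor([5, 2, 0, 1, 4, 3])
P_inv = torch.tensor([2, 3, 1, 5, 4, 0])

# Shuffling: shuffled[i] = x[P[i]]
shuffled = [x[i] for i in P.tolist()]
print(shuffled)  # ['f', 'c', 'a', 'b', 'e', 'd']

# Restoring: indexing the shuffled sequence with P_inv undoes the shuffle
restored = [shuffled[j] for j in P_inv.tolist()]
print(restored)  # ['a', 'b', 'c', 'd', 'e', 'f']

# The same fact at the index level: P[P_inv] is the identity permutation
print(P[P_inv])  # tensor([0, 1, 2, 3, 4, 5])
```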

I wrote the following code based on matrix multiplication (the transpose of a permutation matrix is its inverse), but this approach is too slow when I use it during model training. Is there a faster implementation?

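import torch

n = 10
x = torch.arange(n, dtype=torch.float32)
print('Original array', x)

random_perm_indices = torch.randperm(n)
# eye(n)[p] maps x to x[p]; its transpose is the inverse permutation matrix
perm_matrix = torch.eye(n)[random_perm_indices].t()
x = x[random_perm_indices]
print('Shuffled', x)

# Applying the inverse permutation matrix to [0, 1, ..., n-1] yields the restore indices
restore_indices = torch.arange(n, dtype=torch.float32).view(n, 1)
restore_indices = perm_matrix.mm(restore_indices).view(n).long()
x = x[restore_indices]
print('Restored', x)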

Hi Bin!

If you represent your permutation as a vector of permuted integers
(as you do), you may use argsort() to obtain the inverse permutation:

>>> import torch
>>> p1 = torch.tensor([2, 0, 1])
>>> torch.argsort(p1)
tensor([1, 2, 0])
>>> p2 = torch.tensor([5, 2, 0, 1, 4, 3])
>>> torch.argsort(p2)
tensor([2, 3, 1, 5, 4, 0])
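A minimal sketch of the full round trip with argsort, assuming the data to restore is a 1-D tensor x that was shuffled as x[p] (the names x, p and inv are illustrative). Since argsort(p)[j] is the position in p that holds the value j, indexing the shuffled tensor with it undoes the shuffle.

```python
import torch

n = 10
x = torch.arange(n, dtype=torch.float32) * 10  # some data to shuffle
p = torch.randperm(n)                          # random permutation
shuffled = x[p]

# argsort(p)[j] is the index i with p[i] == j, so p[inv] is the identity
inv = torch.argsort(p)
restored = shuffled[inv]

print(torch.equal(restored, x))  # True
```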


K. Frank


Oh, that’s what I need. Thank you very much. :smile:

@KFrank’s method is a great ad hoc way to invert a permutation.

If the input is rather large (say, 10,000 elements or more) and you know it is a permutation (of 0…9999), you could also use index assignment:

import torch

def inverse_permutation(perm):
    # assumes perm is a permutation of 0..n-1; writes inv[perm[i]] = i
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.size(0), device=perm.device)
    return inv

This gives the same result (again, provided the input really is a permutation), but on my computer, with 100,000 elements, a quick %timeit benchmark shows it is faster both on the CPU (~4 ms for argsort vs. ~165 µs for index assignment) and on the GPU (~147 µs vs. ~20 µs).
In other words, for large operands the O(n log n) complexity of sorting vs. the O(n) complexity of index assignment becomes noticeable.
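A rough sketch of how such a CPU comparison could be reproduced with Python's timeit module. The size n = 100_000 mirrors the post; absolute timings depend on hardware and PyTorch version, so none are claimed here. For a GPU measurement one would also need to call torch.cuda.synchronize() before reading the timer, which this sketch omits.

```python
import timeit
import torch

def inverse_permutation(perm):
    # O(n) index assignment, as in the post above
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.size(0), device=perm.device)
    return inv

n = 100_000
perm = torch.randperm(n)

# Both approaches must agree when the input is a genuine permutation
assert torch.equal(inverse_permutation(perm), torch.argsort(perm))

runs = 100
t_sort = timeit.timeit(lambda: torch.argsort(perm), number=runs) / runs
t_index = timeit.timeit(lambda: inverse_permutation(perm), number=runs) / runs
print(f"argsort: {t_sort * 1e6:.0f} us, index assignment: {t_index * 1e6:.0f} us")
```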

The beauty of @KFrank’s solution is that you don’t need to write a new function, and it instantly generalizes to batches, too. Unless you are on the CPU with large operands and are in a hurry or in a very tight loop, the speed difference probably doesn’t matter much.
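For instance, for a batch of permutations stored as rows, argsort along the last dimension inverts each row independently. A minimal sketch, reusing the two example permutations from the scatter-based batch version later in this thread:

```python
import torch

perms = torch.tensor([[3, 2, 1, 0],
                      [1, 2, 3, 0]])

# argsort along dim=-1 inverts every row separately
perms_inv = torch.argsort(perms, dim=-1)
print(perms_inv)
# tensor([[3, 2, 1, 0],
#         [3, 0, 1, 2]])

# Sanity check: composing each row with its inverse gives the identity
print(torch.gather(perms, 1, perms_inv))
# tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3]])
```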

Best regards



Hi Thomas (and Bin)!

This is absolutely correct – using argsort() to invert a permutation
has the sub-optimal O(n log n) complexity, and will definitely matter with
large permutations.


K. Frank


I just want to add that if the dim indices can be negative (which they may be, since this is Python), one needs to account for that.

So my solution would be:

import pytest
import torch

def invert_permutation(*dims: int):
    n = len(dims)
    # map negative dim indices to their non-negative equivalents
    dims = [d if d >= 0 else d + n for d in dims]
    return torch.argsort(torch.tensor(dims)).tolist()

# passes this test
@pytest.mark.parametrize("p_in,p_expct", [
    ([0, 1, 2], [0, 1, 2]),  # identity
    ([0, 2, 1], [0, 2, 1]),  # swap last 2
    ([-3, -2, 1], [0, 1, 2]),  # identity with neg. ix
    ([0, -1, -2], [0, 2, 1]),  # swap last 2 with neg. ix
])
def test_inverse_permute(p_in, p_expct):
    assert invert_permutation(*p_in) == p_expct
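A short usage sketch, assuming the goal is to undo a tensor.permute call: applying the inverted dims to the permuted tensor restores the original layout. The function is repeated so the snippet runs on its own, with torch.tensor in place of torch.LongTensor; the tensor shape and dims are arbitrary illustrative choices.

```python
import torch

def invert_permutation(*dims: int):
    n = len(dims)
    # map negative dim indices to their non-negative equivalents
    dims = [d if d >= 0 else d + n for d in dims]
    return torch.argsort(torch.tensor(dims)).tolist()

t = torch.randn(2, 3, 4)
dims = (2, 0, -2)            # -2 refers to dim 1
permuted = t.permute(*dims)  # shape (4, 2, 3)

inv = invert_permutation(*dims)
restored = permuted.permute(*inv)

print(inv)                        # [1, 2, 0]
print(restored.shape)             # torch.Size([2, 3, 4])
print(torch.equal(restored, t))   # True
```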


Here’s a batched version à la @tom above:


import torch

perms = torch.tensor([[3, 2, 1, 0],
                      [1, 2, 3, 0]])
src = torch.arange(4)[None].repeat(2, 1)
perms_inv = torch.empty_like(perms)
perms_inv = torch.scatter(perms_inv, dim=1, index=perms, src=src)

Output is:

tensor([[3, 2, 1, 0],
        [3, 0, 1, 2]])

This requires a .repeat of the arange, since gather/scatter don’t support broadcasting, but that’s a minor annoyance.


Thank you for amending the example.
A minor nit: .expand (to the right size) is likely to be a tad faster than .repeat.

Best regards



Oh yeah, I just came back here to point that out! :sweat_smile: Also, expand doesn’t use extra memory since it returns a view, whereas repeat returns a copy.
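Putting both suggestions together, here is a sketch of the batched scatter using .expand_as instead of .repeat. expand_as is used here as one convenient way to "expand to the right size", and the in-place Tensor.scatter_ replaces the out-of-place torch.scatter; both choices are this sketch's assumptions, not the thread's exact code.

```python
import torch

perms = torch.tensor([[3, 2, 1, 0],
                      [1, 2, 3, 0]])

# expand_as returns a broadcast view of arange, so no per-row copy is made
src = torch.arange(perms.size(1)).expand_as(perms)

# perms_inv[b, perms[b, i]] = i for every row b
perms_inv = torch.empty_like(perms).scatter_(1, perms, src)
print(perms_inv)
# tensor([[3, 2, 1, 0],
#         [3, 0, 1, 2]])
```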
