About the permutation returned by `torch.sort`

Hi!
Say I want to sort this list with torch: [1, 5, 3, 4].
The following code in Python 3.x:

from torch import tensor
import torch

t = tensor([1,5,3,4])
t = torch.sort(t, descending=True)

print(t)

returns

torch.return_types.sort(
values=tensor([5, 4, 3, 1]),
indices=tensor([1, 3, 2, 0]))

The `indices` tensor in the output is a permutation of 0, 1, 2, 3. It maps each position of the output values [5, 4, 3, 1] to the position that value had in the input list [1, 5, 3, 4]. Indeed: 5 comes from position 1, 4 from position 3, 3 from position 2 and 1 from position 0, which gives [1, 3, 2, 0].
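As a quick sanity check of this reading, here is a minimal sketch (the variable names `values`, `indices` and `gathered` are only illustrative): indexing the input with `indices` should give back the sorted values.

```python
import torch

t = torch.tensor([1, 5, 3, 4])
values, indices = torch.sort(t, descending=True)

# Position i of the sorted output holds the input element at indices[i],
# so indexing the input with `indices` reproduces the sorted values.
gathered = t[indices]
print(torch.equal(gathered, values))  # True
```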

What would be a simple and efficient way to get the inverse of this permutation, that is, the permutation that maps each position of the input list [1, 5, 3, 4] to its position in the output list [5, 4, 3, 1]? In our case this inverse is [3, 0, 2, 1].

Thanks a lot!

I had this old code, based on torch.nn.utils.rnn:

def invert_permutation(permutation):
    output = torch.empty_like(permutation)
    output.scatter_(0, permutation,
                    torch.arange(0, len(permutation), dtype=torch.int64, device=permutation.device))
    return output
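
For comparison, the same linear-time idea can probably be written with plain index assignment instead of `scatter_`. This is only a sketch: the helper name `invert_permutation_by_indexing` is made up for illustration, and it assumes the input is a 1-D integer tensor that really is a permutation of 0..n-1.

```python
import torch

def invert_permutation_by_indexing(perm):
    # Hypothetical helper (not a PyTorch API): assumes `perm` is a 1-D
    # integer tensor containing each of 0..n-1 exactly once.
    inverse = torch.empty_like(perm)
    # Element perm[i] sits at position i, so the inverse sends perm[i] to i.
    inverse[perm] = torch.arange(perm.numel(), dtype=perm.dtype, device=perm.device)
    return inverse

perm = torch.tensor([1, 3, 2, 0])
inverse = invert_permutation_by_indexing(perm)
print(inverse.tolist())  # [3, 0, 2, 1]
```

Both `perm[inverse]` and `inverse[perm]` should then come out as [0, 1, 2, 3].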

Hi Damien (and Alex)!

In my mind, the cleanest approach is to take the argsort of the
permutation:

import torch
torch.__version__

t = torch.LongTensor ([1, 5, 3, 4])
perm = torch.sort (t, descending = True)[1]
perm
inverse = torch.sort (perm)[1]   # invert permutation
inverse
perm[inverse]    # check inverse
inverse[perm]    # check inverse again

Here is the output:

>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> t = torch.LongTensor ([1, 5, 3, 4])
>>> perm = torch.sort (t, descending = True)[1]
>>> perm

 1
 3
 2
 0
[torch.LongTensor of size 4]

>>> inverse = torch.sort (perm)[1]   # invert permutation
>>> inverse

 3
 0
 2
 1
[torch.LongTensor of size 4]

>>> perm[inverse]    # check inverse

 0
 1
 2
 3
[torch.LongTensor of size 4]

>>> inverse[perm]    # check inverse again

 0
 1
 2
 3
[torch.LongTensor of size 4]

This is, however, an O(n log n) algorithm, whereas a permutation can
be inverted in O(n). Nonetheless, because it uses PyTorch's built-in
sort function, it ought to be quite fast.
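
On recent PyTorch versions the same approach can be written with `torch.argsort`, which returns only the indices. A minimal sketch, assuming a version of PyTorch that provides `torch.argsort` (the session above predates it):

```python
import torch

t = torch.tensor([1, 5, 3, 4])

# Permutation that sorts t in descending order.
perm = torch.argsort(t, descending=True)

# Argsort of the permutation (in increasing order) is its inverse.
inverse = torch.argsort(perm)

print(perm.tolist())     # [1, 3, 2, 0]
print(inverse.tolist())  # [3, 0, 2, 1]

# Applying the inverse to the sorted values restores the original order.
restored = t[perm][inverse]
print(torch.equal(restored, t))  # True
```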

Best.

K. Frank


Thanks @KFrank, taking the argsort of a permutation (in increasing order) gives its inverse!