How to get sorted indices

According to the documentation of torch.argsort, it returns the indices that sort a tensor along a given dimension:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
        [ 0.1598,  0.0788, -0.0745, -1.2700],
        [ 1.2208,  1.0722, -0.7064,  1.2564],
        [ 0.0669, -0.2318, -0.8229, -0.9280]])
>>> torch.argsort(a, dim=1)
tensor([[2, 0, 3, 1],
        [3, 2, 1, 0],
        [2, 1, 0, 3],
        [3, 2, 1, 0]])

However, I want the sorted indices, i.e., the position each element ends up at after its row is sorted. For this example, that is:

tensor([[1, 3, 0, 2],
        [3, 2, 1, 0],
        [2, 1, 0, 3],
        [3, 2, 1, 0]])
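To pin down the difference between the two outputs, here is a minimal pure-Python sketch for a single row. The helper `row_ranks` is hypothetical (not from the post): it first computes the argsort order, then inverts that permutation so that each element gets its position in sorted order.

```python
def row_ranks(row):
    # Argsort: column indices ordered by ascending value.
    order = sorted(range(len(row)), key=lambda i: row[i])
    # Invert the permutation: the element at column order[k] has rank k.
    ranks = [0] * len(row)
    for k, i in enumerate(order):
        ranks[i] = k
    return order, ranks

order, ranks = row_ranks([0.0785, 1.5267, -0.8521, 0.4065])
print(order)  # [2, 0, 3, 1]  (what argsort returns)
print(ranks)  # [1, 3, 0, 2]  (the desired "sorted indices")
```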

Hi Z!

Slap on a second argsort(): the argsort of the argsort gives, for each element, its position in the sorted order, which is what you are calling your sorted indices:

>>> import torch
>>> torch.__version__
'1.12.0'
>>> a = torch.tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
...         [ 0.1598,  0.0788, -0.0745, -1.2700],
...         [ 1.2208,  1.0722, -0.7064,  1.2564],
...         [ 0.0669, -0.2318, -0.8229, -0.9280]])
>>> a.argsort (dim = 1)
tensor([[2, 0, 3, 1],
        [3, 2, 1, 0],
        [2, 1, 0, 3],
        [3, 2, 1, 0]])
>>> a.argsort (dim = 1).argsort (dim = 1)
tensor([[1, 3, 0, 2],
        [3, 2, 1, 0],
        [2, 1, 0, 3],
        [3, 2, 1, 0]])
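Why this works: argsort returns a permutation, and the argsort of a permutation is its inverse, which maps each original position to its place in the sorted order. As a sketch, the same inverse can also be built without a second sort by scattering positions with Tensor.scatter_. The helper names below are hypothetical (not part of PyTorch), and `ranks_via_scatter` assumes a 2-D tensor ranked along dim = 1.

```python
import torch

def ranks_via_double_argsort(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # argsort of the argsort = inverse permutation = rank of each element
    return x.argsort(dim=dim).argsort(dim=dim)

def ranks_via_scatter(x: torch.Tensor) -> torch.Tensor:
    # Assumes a 2-D tensor ranked along dim=1.
    order = x.argsort(dim=1)  # order[r, k] = column of the k-th smallest value
    positions = torch.arange(x.size(1), device=x.device).repeat(x.size(0), 1)
    ranks = torch.empty_like(order)
    ranks.scatter_(1, order, positions)  # ranks[r, order[r, k]] = k
    return ranks

a = torch.tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
                  [ 0.1598,  0.0788, -0.0745, -1.2700],
                  [ 1.2208,  1.0722, -0.7064,  1.2564],
                  [ 0.0669, -0.2318, -0.8229, -0.9280]])

print(ranks_via_double_argsort(a, dim=1))
print(ranks_via_scatter(a))
```

The scatter version replaces the second O(n log n) sort with an O(n) pass. With tied values, the relative order among ties (and therefore their ranks) is not guaranteed by the default sort, so either approach may rank ties arbitrarily.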

Best.

K. Frank