According to the documentation of torch.argsort:
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
        [ 0.1598,  0.0788, -0.0745, -1.2700],
        [ 1.2208,  1.0722, -0.7064,  1.2564],
        [ 0.0669, -0.2318, -0.8229, -0.9280]])
>>> torch.argsort(a, dim=1)
tensor([[2, 0, 3, 1],
        [3, 2, 1, 0],
        [2, 1, 0, 3],
        [3, 2, 1, 0]])
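However, what I want is the rank of each element within its row, i.e. the position each element would occupy after sorting along dim=1. That is: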
tensor([[1, 3, 0, 2],
        [3, 2, 1, 0],
        [2, 1, 0, 3],
        [3, 2, 1, 0]])
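
One way to get this, as a minimal sketch: the desired tensor is the inverse of the permutation that argsort returns, and applying argsort a second time inverts a permutation. The helper name `row_ranks` is hypothetical (it is not part of PyTorch), and the tensor is hard-coded from the values above so the result can be reproduced.

```python
import torch

# The tensor from the question, written out so the result is reproducible.
a = torch.tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
                  [ 0.1598,  0.0788, -0.0745, -1.2700],
                  [ 1.2208,  1.0722, -0.7064,  1.2564],
                  [ 0.0669, -0.2318, -0.8229, -0.9280]])

def row_ranks(x, dim=1):
    """Hypothetical helper: rank of each element along `dim` (0 = smallest)."""
    # order[i, k] is the column index of the k-th smallest element in row i.
    order = torch.argsort(x, dim=dim)
    # Sorting the permutation itself inverts it: the result at [i, j]
    # is the sorted position (rank) of x[i, j].
    return torch.argsort(order, dim=dim)

ranks = row_ranks(a)
print(ranks)
```

For the values above this reproduces the desired tensor. Note that when a row contains equal values, they still receive distinct ranks, and their relative order is not guaranteed by default.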