Torch.argsort Error

The argsort example on the documentation page (torch.argsort — PyTorch 2.3 documentation) has an error on the first line of output. That is, for the first row [ 0.0785, 1.5267, -0.8521, 0.4065], argsort outputs [2, 0, 3, 1]. However, these indices give the descending order rather than the ascending order. All other rows are OK.

Here is what is on the documentation page (and inputting the same “a” as given, locally, also replicates the issue):

a = torch.randn(4, 4)
a
tensor([[ 0.0785, 1.5267, -0.8521, 0.4065],
[ 0.1598, 0.0788, -0.0745, -1.2700],
[ 1.2208, 1.0722, -0.7064, 1.2564],
[ 0.0669, -0.2318, -0.8229, -0.9280]])

torch.argsort(a, dim=1)
tensor([[2, 0, 3, 1],
[3, 2, 1, 0],
[2, 1, 0, 3],
[3, 2, 1, 0]])

@pymex

[ 0.0785, 1.5267, -0.8521, 0.4065]

Output

 [2, 0, 3, 1]

is correct for ascending order:
2 → -0.8521
0 → 0.0785
3 → 0.4065
1 → 1.5267

Is there anything that I am missing?

The code is working as it should; perhaps you got confused by the array indices.


@pymex
Please read the docs
https://numpy.org/doc/stable/reference/generated/numpy.argsort.html

argsort can be difficult to wrap your head around at first. I suggest trying it out on a 1-D array of simple integers.
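As a minimal sketch of that suggestion (the three integers are made up for illustration and do not come from the thread), a 1-D tensor makes it easy to see that each entry of the result is a position in the original tensor:

```python
import torch

# Illustrative values only; any small set of distinct integers would do.
x = torch.tensor([30, 10, 20])

# Each entry of idx is a position in x, listed from the smallest value to the largest.
idx = torch.argsort(x)
x_sorted = x[idx]

print(idx)       # tensor([1, 2, 0])
print(x_sorted)  # tensor([10, 20, 30])
```

Read left to right, idx says: the smallest value sits at x[1], the next at x[2], and the largest at x[0].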


From the docs:

Returns the indices that sort a tensor along a given dimension in ascending order by value.

which is exactly what it does and what @anantguptadbl described:

>>> x = torch.tensor([ 0.0785,  1.5267, -0.8521,  0.4065])
>>> idx = torch.argsort(x)
>>> print(idx)
tensor([2, 0, 3, 1])
>>> x_sorted = x[idx]
>>> print(x_sorted)
tensor([-0.8521,  0.0785,  0.4065,  1.5267])

The returned indices sort the tensor in ascending order.
Further:

This is the second value returned by torch.sort(). See its documentation for the exact semantics of this method.

which is also the case:

>>> torch.sort(x)
torch.return_types.sort(
values=tensor([-0.8521,  0.0785,  0.4065,  1.5267]),
indices=tensor([2, 0, 3, 1]))
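The same check can be run on the full 4x4 tensor from the documentation example. This is only a sketch: it hard-codes the printed values of a (the original was generated by torch.randn) and uses torch.gather to apply each row's indices to that row.

```python
import torch

# Values copied from the documentation example quoted in the first post.
a = torch.tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
                  [ 0.1598,  0.0788, -0.0745, -1.2700],
                  [ 1.2208,  1.0722, -0.7064,  1.2564],
                  [ 0.0669, -0.2318, -0.8229, -0.9280]])

idx = torch.argsort(a, dim=1)

# Pick elements of each row in the order given by idx for that row.
rows_sorted = torch.gather(a, 1, idx)

# torch.sort returns the sorted values together with the same indices.
values, indices = torch.sort(a, dim=1)

print(rows_sorted)
print(torch.equal(rows_sorted, values))  # True
```

Every row of rows_sorted, including the first, comes out in ascending order.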

@ptrblck The indices map as above. Is this not what occurs, as noted, when you write:

x = torch.tensor([ 0.0785, 1.5267, -0.8521, 0.4065])

idx = torch.argsort(x)
print(idx)
tensor([2, 0, 3, 1])

In that case, when you traverse the indices from 0 to 3, how does the mapping yield the values in ascending order when the mapping, as noted above, shows the values descending? Perhaps you can clarify this issue.

argsort() returns the indices of the original list, arranged in the order that sorts it.

In your example, after running idx = torch.argsort(x), the value of idx is tensor([2, 0, 3, 1]). This indicates that the sorted elements of x should appear in the following order: x[2], x[0], x[3], x[1]. With original values of,

x = [0.0785, 1.5267, -0.8521, 0.4065]
#       0       1        2       3

the sorted order should be:

x[idx] = [-0.8521, 0.0785, 0.4065, 1.5267]
#             2       0       3       1
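For anyone who prefers to see the idea without PyTorch, here is a rough pure-Python sketch of what an argsort computes. The helper name argsort is hypothetical and this is not how torch implements it; it only mirrors the behaviour described above.

```python
def argsort(values):
    """Return the positions of `values`, ordered from smallest value to largest.

    Hypothetical helper for illustration; not PyTorch's implementation.
    """
    return sorted(range(len(values)), key=lambda i: values[i])


x = [0.0785, 1.5267, -0.8521, 0.4065]
idx = argsort(x)

# Reading x at the positions in idx, in order, gives the sorted list.
x_sorted = [x[i] for i in idx]

print(idx)       # [2, 0, 3, 1]
print(x_sorted)  # [-0.8521, 0.0785, 0.4065, 1.5267]
```

The indices are positions to read from, so idx[0] == 2 means "the smallest value is at x[2]", not "x[0] ends up in position 2".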

Check my posted code snippet, which shows:

>>> x_sorted = x[idx]
>>> print(x_sorted)
tensor([-0.8521,  0.0785,  0.4065,  1.5267])

and matches what the docs say:

Returns the indices that sort a tensor along a given dimension in ascending order by value.

and which was already clarified multiple times.


I don’t think you are understanding the documentation, and I don’t know how to explain it any further besides reciting what the simple sentence in the docs already states.
The returned indices sort the passed input tensor, as clearly described. The same indices are returned via torch.sort and correspond to the sorted values.
Three different users tried to explain this behavior, so I’m unsure how to help further.

For others in case you are visiting this topic: the docs are correct and do not need any changes.


Ok. The confusing part is the documentation. At first glance, the last three rows give the impression that the indices of the elements are those of the sorted input tensor. After all, three rows give this impression and one row seems out of sync. However, that is not the case.

The documentation requires further digging, since it says only that the indices are the “second value returned by torch.sort(). See its documentation”. After looking this up, it becomes clear that the indices “are the indices of the elements in the original input tensor” rather than of the sorted tensor.

It would be well to simply state in the documentation that the returned indices are the “indices of the elements in the original input tensor” rather than saying “second value returned by torch.sort(). See its documentation”. After all, I’ve seen this same issue brought up before in other posts (but the replies and answers were not clear). A simple fix would be to directly state this fact on the page rather than require additional digging. It is a better practice to be as clear as possible up front rather than to require readers to dig further than needed, especially when three of the four rows give a different impression of what the expected output should be. It makes one row seem wrong.
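One possible way to see why only the first row looks out of sync (this is an interpretation, not something stated in the documentation) is to compare the argsort output with the rank of each element, meaning the position each element would occupy after sorting. Applying argsort to a permutation yields its inverse, so the sketch below uses a second argsort to compute the ranks:

```python
import torch

# Values copied from the documentation example quoted in the first post.
a = torch.tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
                  [ 0.1598,  0.0788, -0.0745, -1.2700],
                  [ 1.2208,  1.0722, -0.7064,  1.2564],
                  [ 0.0669, -0.2318, -0.8229, -0.9280]])

# idx[r, k] is the position in row r of the k-th smallest value.
idx = torch.argsort(a, dim=1)

# ranks[r, j] is the position that a[r, j] takes once row r is sorted.
ranks = torch.argsort(idx, dim=1)

print(idx[0])    # tensor([2, 0, 3, 1])
print(ranks[0])  # tensor([1, 3, 0, 2])

# Rows where both readings happen to give the same numbers.
print((idx == ranks).all(dim=1))  # tensor([False,  True,  True,  True])
```

For the last three rows the two readings coincide, so reading the output as ranks appears to work there; for the first row they differ, and the "indices of the elements in the original input tensor" reading is the one that holds for every row.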