How to sort 2d tensors by row-wise pairs

I am trying to sort a tensor by its first column and then, while preserving that ordering, sort by the second column to break ties. The tensors are always 2d, and they contain only torch.int64 elements.

The intuitive way to think about it is that each row represents a sample, and I want to sort these samples in ascending order, first by the first column and then by the second.

For example, if I have a tensor:

a = torch.tensor([ [3,0], [2,5], [2,4], [3,9] ])

I would like the output to be:

torch.tensor([ [2,4], [2,5], [3,0], [3,9] ])

I suspect the “stable” argument of the torch.sort function can help me achieve this, but I’m not sure I understand the documentation.
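For reference, here is a small sketch of what stable=True guarantees, using the first column of the example as a 1-D tensor. With stable=True, torch.sort() keeps tied elements in their original relative order, so the returned indices are deterministic even when values repeat:

```python
import torch

# The first column of the example tensor: note the repeated 2s and 3s.
x = torch.tensor([3, 2, 2, 3])

# stable=True keeps equal elements in the order they appeared in x.
values, indices = torch.sort(x, stable=True)
print(values)   # tensor([2, 2, 3, 3])
print(indices)  # tensor([1, 2, 0, 3])
```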

Hi Victor!

Yes, stable will be part of the approach.

Because you want to reorder the rows themselves, rather than the elements within
a given column, you will need to index into a with the sorted-order indices given by
argsort().
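To make the distinction concrete, here is a sketch contrasting torch.sort() along dim=0, which sorts each column independently, with indexing by argsort(), which moves whole rows. Note that one stable sort on the first column alone leaves the tied rows [2, 5] and [2, 4] in their original order, which is why the second column still has to be taken into account:

```python
import torch

a = torch.tensor([[3, 0], [2, 5], [2, 4], [3, 9]])

# Sorting along dim=0 sorts each column on its own, so rows get mixed up:
# the result contains [2, 0] and [3, 5], which are not rows of a.
col_sorted = torch.sort(a, dim=0).values
print(col_sorted)

# argsort returns a permutation of row indices; indexing with it moves whole rows.
order = torch.argsort(a[:, 0], stable=True)
rows_sorted = a[order]
print(rows_sorted)
```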

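Then, using a standard approach for sorting by a “primary” and “secondary” key,
first sort by the secondary key, and then perform a stable sort by the primary key.
Because the second sort is stable, rows with equal primary keys keep the order
established by the first sort, so ties are broken by the secondary key.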
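The same idea extends to more than two columns. The following is a minimal sketch, assuming rows should be ordered lexicographically by column 0, then column 1, and so on; lexsort_rows is a hypothetical helper name, not a PyTorch function. It sorts by the last (least significant) column first and then stable-sorts by each earlier column in turn:

```python
import torch


def lexsort_rows(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sort the rows of a 2-D tensor lexicographically.

    Column 0 is the primary key, column 1 the secondary key, and so on.
    """
    out = t
    # Walk from the least significant column to the most significant one.
    for col in reversed(range(t.shape[1])):
        # stable=True preserves the order set by the less significant columns.
        order = torch.argsort(out[:, col], stable=True)
        out = out[order]
    return out


a = torch.tensor([[3, 0], [2, 5], [2, 4], [3, 9]])
print(lexsort_rows(a))
```

With a from the question this should print the rows in the order [2, 4], [2, 5], [3, 0], [3, 9]. Each pass re-indexes the whole tensor; composing the permutations and indexing once at the end would also work, but for a handful of columns the difference is unlikely to matter.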

Thus:

>>> import torch
>>> torch.__version__
'2.3.0'
>>> a = torch.tensor([ [3,0], [2,5], [2,4], [3,9] ])
>>> b = a[a[:, 1].argsort (dim = 0, stable = False)]
>>> b[b[:, 0].argsort (dim = 0, stable = True)]
tensor([[2, 4],
        [2, 5],
        [3, 0],
        [3, 9]])
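One way to gain confidence in the two-pass approach is to compare it against Python's built-in sorted(), which orders lists lexicographically, on random data with many ties. A sketch (the seed, size, and value range are arbitrary choices):

```python
import torch

torch.manual_seed(0)
# Random int64 rows with a small value range, so both columns have many ties.
a = torch.randint(0, 5, (1000, 2))

# Two-pass approach: sort by the secondary key, then stable-sort by the primary key.
b = a[torch.argsort(a[:, 1])]
result = b[torch.argsort(b[:, 0], stable=True)]

# Python compares lists element by element, so sorted() gives the reference order.
expected = sorted(a.tolist())
print(result.tolist() == expected)
```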

Best.

K. Frank
