Sorting 2D tensor by pairs, not columnwise

Let’s say I have a 2D tensor A of shape (N, 2), and I would like to sort its rows as pairs, not each column separately.
In other words, I am looking for a permutation of the rows of A such that, after sorting, the following holds for every i < j:
(A[i, 0] < A[j, 0]) or ((A[i, 0] == A[j, 0]) and (A[i, 1] <= A[j, 1]))

For example, let’s suppose I have the following tensor:

a = torch.FloatTensor(
   [[5, 5],
    [5, 3],
    [3, 5],
    [6, 4],
    [3, 7]])

and after sorting by pairs I would like to have:

[[3, 5],
 [3, 7],
 [5, 3],
 [5, 5],
 [6, 4]]

However, torch.sort sorts each column independently, as a separate 1D tensor:

In[]: a.sort(dim=0)
Out[]: 
torch.return_types.sort(
values=tensor([[3., 3.],
        [3., 4.],
        [5., 5.],
        [5., 5.],
        [6., 7.]]),
indices=tensor([[2, 1],
        [4, 3],
        [0, 0],
        [1, 2],
        [3, 4]]))

Is there any way to do this in PyTorch?
Thank you.

Hi,

Do you know the maximum value that can be contained in the Tensor? Can the values be negative?

Hi @albanD,
Let’s suppose that I know the maximum value and the values cannot be negative (I can always renormalize values in my tensor this way). The values can be of float type.

If you know a bound M that is strictly greater than every value in the second column, you can pack each row into a single key and sort on that:

augmented_a = a.select(1, 0) * M + a.select(1, 1)
ind = augmented_a.sort().indices

res = a.index_select(0, ind)

(I didn’t test it, but that should work.)
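
For anyone who wants to try the key-packing idea on the example from the question, here is a minimal sketch. It assumes the second column is non-negative, and it picks M as the maximum of the second column plus one; that choice of M is an assumption made for this illustration, not something stated above.

```python
import torch

a = torch.tensor(
    [[5, 5],
     [5, 3],
     [3, 5],
     [6, 4],
     [3, 7]], dtype=torch.float32)

# Assumption for this sketch: M is strictly greater than every value in the
# second column, so two different rows can never produce the same key.
M = a[:, 1].max() + 1

# Pack each (first, second) pair into one scalar key, then sort the keys.
keys = a[:, 0] * M + a[:, 1]
ind = keys.sort().indices

# Reorder the rows with the permutation found on the keys.
res = a.index_select(0, ind)
print(res)
```

On this small input the keys are all distinct and exactly representable, so the result matches the ordering asked for in the question.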

If you don’t know the max, you can play a similar trick when the values have a bounded number of digits, by converting them to strings.

Thank you!
This does work, but I think it can lose precision in some cases. For example, in my particular case the first column holds integer values (of type long) and the second column holds floating-point values (float32). When I construct augmented_a, I get a floating-point 1D tensor, and float32 can represent every integer exactly only within [-16777216, 16777216] (according to Wikipedia). a.select(1, 0) * M can exceed that range if M is large.
Maybe there is a more suitable way?
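
To make the concern concrete, here is a small sketch with made-up values (two hypothetical rows and M = 2**24, chosen only to hit the float32 limit). The two rows differ in the second column, yet their float32 keys come out identical, so the sort can no longer tell them apart.

```python
import torch

# Hypothetical data: same first value, different second values.
first = torch.tensor([1, 1], dtype=torch.long)
second = torch.tensor([0.75, 0.25], dtype=torch.float32)

# Illustrative bound: 2**24 is where float32 stops resolving fractions.
M = 2 ** 24

# Keys built in float32: both round to 16777216.0.
key32 = first.to(torch.float32) * M + second

# Keys built in float64 keep the fractional part for these values.
key64 = first.to(torch.float64) * M + second.to(torch.float64)

collide32 = bool(key32[0] == key32[1])
collide64 = bool(key64[0] == key64[1])
print(collide32, collide64)
```

Building the key in float64 avoids the collision here, but it only moves the limit (to 2**53 for exact integers), so it does not remove the problem in general.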

Yes, this is only a simple solution for the case where the numbers are small.
It has the advantage of being fast and using only tensor-wide operations.

I’m not sure how to do this efficiently though.
If you find an efficient algorithm that mostly uses matrices (from numpy or matlab) that does this, do link it here, I’m sure we can adapt it to pytorch.

In NumPy I could do this via sorting by the 2nd column then stable sorting by the 1st column:

In [54]: a = np.array(
    ...:     [[5, 5],
    ...:      [5, 3],
    ...:      [3, 5],
    ...:      [6, 4],
    ...:      [3, 7]])
    ...:      

In [55]: inner_sorting = np.argsort(a[:, 1])    # here it does not matter whether we do it in stable way or not

In [56]: a_inner_sorted = a[inner_sorting]

In [57]: a_inner_sorted
Out[57]: 
array([[5, 3],
       [6, 4],
       [5, 5],
       [3, 5],
       [3, 7]])

In [58]: outer_sorting = np.argsort(a_inner_sorted[:, 0], kind='stable')

In [59]: a_outer_sorted = a_inner_sorted[outer_sorting]

In [60]: a_outer_sorted
Out[60]: 
array([[3, 5],
       [3, 7],
       [5, 3],
       [5, 5],
       [6, 4]])

However, I have not found a stable sort in PyTorch. Is there a way to perform one? If so, I think this method would work for columns of any type.

Right, unfortunately we don’t have stable sort implemented.
There is already some discussion about adding stable sorts (for topk, for example: https://github.com/pytorch/pytorch/issues/27542).
If this is really a blocker for you, you can open a new issue to add a stable sort explaining your use case.

Great, I will post an issue on GitHub. I could get by with the simple solution, but it might lead to precision loss, so it’s not suitable for me in all cases.
Thank you!

In case someone comes looking here, the NumPy solution mentioned by @seva100 now works with PyTorch. Just use the stable=True keyword argument.
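
A minimal sketch of that two-pass approach in PyTorch, using the example tensor from the question. It assumes a PyTorch version in which torch.sort accepts stable=True; the variable names are just for illustration.

```python
import torch

a = torch.tensor(
    [[5, 5],
     [5, 3],
     [3, 5],
     [6, 4],
     [3, 7]], dtype=torch.float32)

# Pass 1: sort by the second column (the secondary key).
inner = torch.sort(a[:, 1], stable=True).indices
a_inner = a[inner]

# Pass 2: stable sort by the first column (the primary key).
# Stability keeps rows with equal first values in their pass-1 order.
outer = torch.sort(a_inner[:, 0], stable=True).indices
res = a_inner[outer]

# Single permutation of the original rows that gives the same result.
perm = inner[outer]
print(res)
print(perm)
```

Because no combined key is built, the two columns never have to share a dtype or a value range, which sidesteps the precision issue discussed earlier in the thread.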