Let’s say I have a 2D tensor A of shape (N, 2), and I would like to sort its rows as pairs, not each column separately.
In other words, I am looking for an expression that finds a permutation of the rows of A such that, after sorting, the following holds whenever i < j:
(A[i, 0] < A[j, 0]) or ((A[i, 0] == A[j, 0]) and (A[i, 1] <= A[j, 1]))
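To make the condition above concrete, here is a small sketch of a checker for it. The function name `is_sorted_by_pairs` is my own (hypothetical), and it works on plain nested lists rather than tensors, so it only illustrates the ordering I am after and is not a PyTorch solution.

```python
def is_sorted_by_pairs(rows):
    """Return True if the rows are ordered by first element, then by second on ties."""
    # The ordering is transitive, so checking each adjacent pair of rows
    # is enough to cover every i < j.
    for (a0, a1), (b0, b1) in zip(rows, rows[1:]):
        ok = (a0 < b0) or ((a0 == b0) and (a1 <= b1))
        if not ok:
            return False
    return True
```

With a tensor, something like `is_sorted_by_pairs(A.tolist())` should work the same way, since `tolist()` gives nested Python lists.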
For example, let’s suppose I have the following tensor:
A = torch.FloatTensor(
    [[5, 5],
     [5, 3],
     [3, 5],
     [6, 4],
     [3, 7]])
and after sorting by pairs I would like to have:
[[3, 5],
 [3, 7],
 [5, 3],
 [5, 5],
 [6, 4]]
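For reference, this is the same ordering that plain Python gives when sorting tuples, which compare by first element and then by second on ties. The sketch below uses nested lists and the names `rows` and `expected` purely for illustration; it leaves PyTorch entirely, which is what I would like to avoid.

```python
rows = [[5, 5], [5, 3], [3, 5], [6, 4], [3, 7]]

# Tuples compare lexicographically: first element, then second on ties.
expected = [list(pair) for pair in sorted(tuple(r) for r in rows)]

print(expected)  # [[3, 5], [3, 7], [5, 3], [5, 5], [6, 4]]
```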
At the same time, torch.sort sorts each column as a separate 1D tensor:
In[]: A.sort(axis=0)
Out[]:
torch.return_types.sort(
values=tensor([[3., 3.],
[3., 4.],
[5., 5.],
[5., 5.],
[6., 7.]]),
indices=tensor([[2, 1],
[4, 3],
[0, 0],
[1, 2],
[3, 4]]))
Is there any way to do this in PyTorch?
Thank you.