Let’s say I have a 2D tensor `A` of shape `(N, 2)`, and I would like to sort its rows as pairs rather than sorting each column separately.

In other words, I am looking for an expression that finds a permutation of the rows of `A` such that, after sorting, the following holds for every `i < j`:

`(A[i, 0] < A[j, 0]) or ((A[i, 0] == A[j, 0]) and (A[i, 1] <= A[j, 1]))`
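
To make the condition concrete, here is a small checker (a sketch; `is_pair_sorted` is a hypothetical helper name, not a PyTorch function). It only compares adjacent rows, which is enough because this lexicographic order is transitive:

```python
import torch

def is_pair_sorted(A: torch.Tensor) -> bool:
    """Return True if the rows of an (N, 2) tensor satisfy the pair ordering."""
    prev, nxt = A[:-1], A[1:]  # each row paired with the row after it
    ok = (prev[:, 0] < nxt[:, 0]) | (
        (prev[:, 0] == nxt[:, 0]) & (prev[:, 1] <= nxt[:, 1])
    )
    # .all() of an empty tensor is True, so 0- or 1-row tensors count as sorted.
    return bool(ok.all())
```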

For example, let’s suppose I have the following tensor:

```
import torch

A = torch.tensor(
    [[5, 5],
     [5, 3],
     [3, 5],
     [6, 4],
     [3, 7]], dtype=torch.float32)
```

and after sorting by pairs I would like to have:

```
[[3, 5],
 [3, 7],
 [5, 3],
 [5, 5],
 [6, 4]]
```
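
For reference, this is exactly the order that plain Python gives when sorting lists of lists, since lists are compared element by element. This sketch only confirms the expected result on CPU lists; it is not a PyTorch solution:

```python
# The rows from the example, as plain Python lists.
rows = [[5, 5], [5, 3], [3, 5], [6, 4], [3, 7]]

# Lists compare lexicographically, so sorted() orders rows by column 0
# and breaks ties with column 1.
expected = sorted(rows)
print(expected)  # [[3, 5], [3, 7], [5, 3], [5, 5], [6, 4]]
```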

However, `torch.sort` sorts each column as a separate 1D tensor, which breaks the pairs apart:

```
In[]: A.sort(axis=0)
Out[]:
torch.return_types.sort(
values=tensor([[3., 3.],
        [3., 4.],
        [5., 5.],
        [5., 5.],
        [6., 7.]]),
indices=tensor([[2, 1],
        [4, 3],
        [0, 0],
        [1, 2],
        [3, 4]]))
```
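
One possible approach, offered only as a sketch: sort row indices instead of values, and gather whole rows at the end so each pair stays together. It assumes a PyTorch version whose `torch.sort` accepts `stable=True`. Two stable passes are used: first by the secondary key (column 1), then by the primary key (column 0), so ties in column 0 keep their column-1 order. `sort_rows_by_pairs` is a hypothetical helper name:

```python
import torch

def sort_rows_by_pairs(A: torch.Tensor) -> torch.Tensor:
    # Pass 1: stable sort of row indices by the secondary key (column 1).
    idx = torch.sort(A[:, 1], stable=True).indices
    # Pass 2: stable sort of those indices by the primary key (column 0);
    # stability preserves the column-1 order among equal column-0 values.
    idx = idx[torch.sort(A[idx, 0], stable=True).indices]
    # Gather complete rows so the pairs are never split.
    return A[idx]

A = torch.tensor([[5, 5], [5, 3], [3, 5], [6, 4], [3, 7]], dtype=torch.float32)
print(sort_rows_by_pairs(A))  # rows: [3, 5], [3, 7], [5, 3], [5, 5], [6, 4]
```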

Is there any way to do this in PyTorch?

Thank you.