How to sort a 3d tensor by coordinates in last dimension

I have a tensor with shape [bn, k, 2]. The last dimension holds coordinates, and I want each batch to be sorted independently by the y coordinate ([:, :, 0]). My approach looks something like this:

import torch
a = torch.randn(2, 5, 2)
indices = a[:, :, 0].sort()[1]
a_sorted = a[:, indices]

print(a)
print(a_sorted)

So far so good, but it now sorts both batches according to both index lists, so I get 4 batches in total:

a
tensor([[[ 0.5160,  0.3257],
         [-1.2410, -0.8361],
         [ 1.3826, -1.1308],
         [ 0.0338,  0.1665],
         [-0.9375, -0.3081]],

        [[ 0.4140, -1.0962],
         [ 0.9847, -0.7231],
         [-0.0110,  0.6437],
         [-0.4914,  0.2473],
         [-0.0938, -0.0722]]])

a_sorted
tensor([[[[-1.2410, -0.8361],
          [-0.9375, -0.3081],
          [ 0.0338,  0.1665],
          [ 0.5160,  0.3257],
          [ 1.3826, -1.1308]],

         [[ 0.0338,  0.1665],
          [-0.9375, -0.3081],
          [ 1.3826, -1.1308],
          [ 0.5160,  0.3257],
          [-1.2410, -0.8361]]],


        [[[ 0.9847, -0.7231],
          [-0.0938, -0.0722],
          [-0.4914,  0.2473],
          [ 0.4140, -1.0962],
          [-0.0110,  0.6437]],

         [[-0.4914,  0.2473],
          [-0.0938, -0.0722],
          [-0.0110,  0.6437],
          [ 0.4140, -1.0962],
          [ 0.9847, -0.7231]]]])

As you can see, I want only the 1st and the 4th batch to be returned (each batch sorted by its own index list). How do I do that?

This should work:

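import torch

a = torch.randn(2, 5, 2)
# Per-batch sort order along dim 1, shape [bn, k]
indices = a[:, :, 0].sort()[1]
# Pair each row of indices with its own batch index ([bn, 1] broadcasts to [bn, k])
a_sorted = a[torch.arange(a.size(0)).unsqueeze(1), indices]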

print(a)
print(a_sorted)
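
For anyone wondering why this works: a[:, indices] applies the whole [bn, k] index tensor to every batch, which gives a [bn, bn, k, 2] result. Indexing with a [bn, 1] batch index alongside indices broadcasts the two together, so each batch only uses its own row of indices. Here is a minimal sketch that checks this against a plain per-batch loop and also shows what I believe is an equivalent torch.gather version. The names batch_idx, gather_idx and reference are just mine for illustration, and the seed is arbitrary:

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
a = torch.randn(2, 5, 2)                 # [bn, k, 2]
indices = a[:, :, 0].argsort(dim=1)      # [bn, k], sort order per batch

# The original attempt indexes every batch with every index row
assert a[:, indices].shape == (2, 2, 5, 2)

# Advanced indexing: batch_idx [bn, 1] broadcasts against indices [bn, k]
batch_idx = torch.arange(a.size(0)).unsqueeze(1)
sorted_by_index = a[batch_idx, indices]  # [bn, k, 2]

# Alternative with torch.gather: the index must match the output shape,
# so repeat each index across the last (coordinate) dimension
gather_idx = indices.unsqueeze(-1).expand(-1, -1, a.size(-1))
sorted_by_gather = torch.gather(a, 1, gather_idx)

# Reference: sort each batch on its own in a Python loop
reference = torch.stack([x[x[:, 0].argsort()] for x in a])

assert torch.equal(sorted_by_index, reference)
assert torch.equal(sorted_by_gather, reference)
```

Both versions should give the same result; the indexing form is shorter, while gather makes the dimension being reordered explicit.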