I have a tensor with shape [bn, k, 2]. The last dimension holds the coordinates, and I want each batch to be sorted independently by the y coordinate ([:, :, 0]). My approach looks something like this:
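import torch
a = torch.randn(2, 5, 2)  # shape [bn, k, 2] with bn=2, k=5
indices = a[:, :, 0].sort()[1]  # per-batch sort order of coordinate 0, shape [bn, k]
a_sorted = a[:, indices]  # indexes dim 1 of every batch with the whole [bn, k] index tensor
print(a)
print(a_sorted)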
So far so good, but now each batch is sorted according to both index lists, so I get 4 batches in total (a_sorted has shape [2, 2, 5, 2] instead of [2, 5, 2]):
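A minimal sketch of why the extra dimension appears, assuming the same random setup as above (the seed is added only for reproducibility): indexing dim 1 with a [bn, k] index tensor pairs every batch with every index row, so the wanted results sit on the "diagonal" a_sorted[i, i]. Picking them out with a loop works, but it still computes all bn * bn combinations first.

```python
import torch

torch.manual_seed(0)  # reproducibility only
a = torch.randn(2, 5, 2)             # [bn, k, 2]
indices = a[:, :, 0].sort()[1]       # [bn, k]

# a[:, indices] has shape [bn] + indices.shape + [2] = [bn, bn, k, 2]:
# a_sorted[b, i] is batch b reordered by the index row of batch i.
a_sorted = a[:, indices]
print(a_sorted.shape)  # torch.Size([2, 2, 5, 2])

# The wanted entries are the "diagonal" ones: batch i sorted by its own row i.
wanted = torch.stack([a_sorted[i, i] for i in range(a.shape[0])])
print(wanted.shape)  # torch.Size([2, 5, 2])
```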
a
tensor([[[ 0.5160, 0.3257],
[-1.2410, -0.8361],
[ 1.3826, -1.1308],
[ 0.0338, 0.1665],
[-0.9375, -0.3081]],
[[ 0.4140, -1.0962],
[ 0.9847, -0.7231],
[-0.0110, 0.6437],
[-0.4914, 0.2473],
[-0.0938, -0.0722]]])
a_sorted
tensor([[[[-1.2410, -0.8361],
[-0.9375, -0.3081],
[ 0.0338, 0.1665],
[ 0.5160, 0.3257],
[ 1.3826, -1.1308]],
[[ 0.0338, 0.1665],
[-0.9375, -0.3081],
[ 1.3826, -1.1308],
[ 0.5160, 0.3257],
[-1.2410, -0.8361]]],
[[[ 0.9847, -0.7231],
[-0.0938, -0.0722],
[-0.4914, 0.2473],
[ 0.4140, -1.0962],
[-0.0110, 0.6437]],
[[-0.4914, 0.2473],
[-0.0938, -0.0722],
[-0.0110, 0.6437],
[ 0.4140, -1.0962],
[ 0.9847, -0.7231]]]])
As you can see, I only want the 1st and the 4th batch to be returned (a_sorted[0, 0] and a_sorted[1, 1], i.e. each batch sorted by its own indices). How do I do that?
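One possible approach, sketched here rather than given as a definitive answer: either index dims 0 and 1 together, with a batch-index tensor from torch.arange that broadcasts against indices, or use torch.gather along dim 1 with the indices expanded over the last dimension. Both return shape [bn, k, 2] directly, without building the [bn, bn, k, 2] intermediate. The variable names batch and a_sorted_gather are only illustrative.

```python
import torch

torch.manual_seed(0)  # reproducibility only
a = torch.randn(2, 5, 2)                 # [bn, k, 2]
indices = a[:, :, 0].argsort(dim=1)      # same as a[:, :, 0].sort()[1], shape [bn, k]

# Option 1: advanced indexing on dims 0 and 1 together.
# batch has shape [bn, 1] and broadcasts against indices ([bn, k]),
# so element (i, j) selects a[i, indices[i, j], :].
batch = torch.arange(a.shape[0]).unsqueeze(1)
a_sorted = a[batch, indices]             # [bn, k, 2]

# Option 2: gather along dim 1; the index must have the same number of dims as a,
# so the [bn, k] indices are expanded to [bn, k, 2].
a_sorted_gather = torch.gather(a, 1, indices.unsqueeze(-1).expand(-1, -1, a.shape[2]))

print(torch.equal(a_sorted, a_sorted_gather))  # True
```

Option 1 is usually the more readable of the two; gather can be handy when the sort dimension is not fixed in advance.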