I need help implementing batch sorting in PyTorch. Suppose I have a tensor of size 6x492x5, where 6 is the batch size. For each batch element (a 492x5 feature matrix), I want to sort its 492 rows by their last (5th) column. Can anyone please help me?
x = torch.randn((6, 492, 5))
sorted_x, _ = x.sort()  # sorts along the last dim, i.e. the 5 values within each row
But sort() doesn't do what I want: it reorders the values inside each row instead of reordering the rows themselves.
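For reference, here is a small check of what the default call does, using a made-up 1x3x5 tensor (the values are only for illustration). Tensor.sort() with no arguments sorts along the last dimension (dim=-1) in ascending order, so each row's 5 values are sorted on their own, the row order never changes, and the last column ends up holding each row's maximum instead of the original key values.

```python
import torch

# Hypothetical toy input: 1 batch element, 3 rows, 5 columns
x = torch.tensor([[[0.09, 0.01, 0.10, 0.20, 0.12],
                   [1.00, 0.50, 0.30, 0.00, 0.11],
                   [0.34, 0.652, 0.052, 0.52, 0.20]]])

# Default sort: dim=-1, descending=False -> values within each row are sorted
sorted_x, idx = x.sort()

print(sorted_x[0, 0])      # first row, its 5 values now in ascending order
print(sorted_x[0, :, -1])  # last column now holds each row's maximum, not the key
```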
I want to reorder the rows of each batch element by the values in the last column, in descending order. The output should look like this:
tensor([[[0.3400, 0.6520, 0.0520, 0.5200, 0.2000],
         [0.0900, 0.0100, 0.1000, 0.2000, 0.1200],
         [1.0000, 0.5000, 0.3000, 0.0000, 0.1100]],
Keep in mind that each of the 492 items in the second dimension also has 5 values. How do you want those to be sorted?
In general there is no direct way to do it in one call, but you can take the column you want to sort by (e.g. x[:, :, -1]) and call torch.sort() on it (with descending=True), which gives you the indices of the sorted rows. With those indices you can use torch.index_select() to get a sorted tensor. Be careful with the other dimensions, not just the batch dimension and the one you sort along: index_select() expects a 1-D index, so it works on one batch element at a time, and each row has to keep its 5 values together.
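A minimal sketch of the approach described above, assuming the key is the last column and the order is descending as in the question. The helper names sort_rows_by_column and sort_rows_loop are hypothetical. The loop version follows the suggestion literally with torch.sort plus torch.index_select (one call per batch element, since index_select takes a 1-D index); the batched version swaps in torch.gather, with the row indices expanded across the 5 columns so every row moves as a unit.

```python
import torch


def sort_rows_by_column(x: torch.Tensor, col: int = -1, descending: bool = True) -> torch.Tensor:
    """Reorder the rows of each batch element of x (shape B x N x C) by column `col`."""
    keys = x[:, :, col]                                # B x N: the column to sort by
    idx = keys.argsort(dim=1, descending=descending)   # B x N: row order per batch element
    # Repeat each row index across the C columns so whole rows are gathered together
    idx = idx.unsqueeze(-1).expand(-1, -1, x.size(2))  # B x N x C
    return torch.gather(x, 1, idx)


def sort_rows_loop(x: torch.Tensor, col: int = -1, descending: bool = True) -> torch.Tensor:
    """Same result, using torch.sort + torch.index_select on one batch element at a time."""
    out = []
    for b in range(x.size(0)):
        _, idx = torch.sort(x[b, :, col], descending=descending)  # 1-D row indices, length N
        out.append(torch.index_select(x[b], 0, idx))              # select whole rows in that order
    return torch.stack(out)


x = torch.randn(6, 492, 5)
y = sort_rows_by_column(x)
print(y.shape)  # torch.Size([6, 492, 5])
```

The gather version avoids the Python loop over the batch. Both versions assume there are no ties in the key column; if the order of tied rows matters, check whether your PyTorch version supports stable=True on sort/argsort.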