Batch sorting in pytorch?

I need help implementing batch sorting in PyTorch. Suppose I have a tensor of size 6x492x5, where 6 is the batch size. I want to sort the rows of each batch element (i.e. each 492x5 feature matrix) by its last column (i.e. the 5th column). Can anyone please help me?

import torch

x = torch.randn((6, 492, 5))
sorted_x, _ = x.sort()


But the sort function doesn't do exactly what I want.
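For reference, here is a tiny check of what `x.sort()` does when called without arguments. The small tensor is made up for illustration; it is not the data from the question.

```python
import torch

# Made-up example with shape (1, 2, 3): one batch element, two rows, three features.
x = torch.tensor([[[3.0, 1.0, 2.0],
                   [0.5, 0.9, 0.1]]])

# With no arguments, sort() works along the last dimension (dim=-1),
# so the values inside each row are reordered independently of each other.
sorted_x, _ = x.sort()
print(sorted_x)
# tensor([[[1.0000, 2.0000, 3.0000],
#          [0.1000, 0.5000, 0.9000]]])
```

The rows themselves are not moved as units, which is why this does not give a per-batch ordering of rows by one column.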

I want to sort each batch element based on its last column. Moreover, the sorting should be in descending order, and the output should look like this:

tensor([[[0.3400, 0.6520, 0.0520, 0.5200, 0.2000],
         [0.0900, 0.0100, 0.1000, 0.2000, 0.1200],
         [1.0000, 0.5000, 0.3000, 0.0000, 0.1100]],

Keep in mind that each item along the second dimension (492) itself has 5 values. How do you want those to be sorted?
In general there is no direct way to do it, but you can store the column you want in a new tensor and then use torch.sort(), which also returns the indices of the sorted items. With those indices you can use torch.index_select() to get a sorted tensor. Because torch.index_select() takes a single 1-D index, it has to be applied to each batch element separately. You also need to be careful about the other dimensions, not just the batch dimension and the one along which you want to sort.
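The idea above (sort the key column, then reorder with the resulting indices) can be sketched without a loop over the batch. This is only one possible way to do it: it swaps torch.index_select() for torch.gather(), which accepts a different index per batch element. The function name `sort_rows_by_column` and its parameters are made up for this sketch, and it assumes a 3-D input of shape (B, N, F).

```python
import torch

def sort_rows_by_column(x: torch.Tensor, col: int = -1, descending: bool = True) -> torch.Tensor:
    """Sort the rows of each batch element of a (B, N, F) tensor by one column."""
    # Key column for every batch element: shape (B, N).
    keys = x[:, :, col]
    # Row order per batch element: shape (B, N).
    order = torch.argsort(keys, dim=1, descending=descending)
    # Expand to (B, N, F) so every feature in a row follows the same permutation.
    index = order.unsqueeze(-1).expand(-1, -1, x.size(-1))
    # out[b, n, f] = x[b, order[b, n], f]
    return torch.gather(x, 1, index)

x = torch.randn(6, 492, 5)
sorted_x = sort_rows_by_column(x)  # each 492x5 block sorted by its last column, descending
```

Expanding the index over the feature dimension is what keeps each row intact while the rows are reordered.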

Bit late to the party, but oh well.
With this you can select the feature you want to sort by.

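import torch

def sort_by_feature(batch: torch.Tensor, ifeature: int):
    # batch has shape (B, N, F); the N rows of each batch element are sorted by feature `ifeature`.
    assert 0 <= ifeature < batch.shape[-1]
    # Row order for every batch element, flattened to length B * N.
    sorted_ftx_idxs = torch.argsort(batch[..., ifeature]).reshape(-1)
    # Batch index of every row (0 repeated N times, then 1, ...), to pair with sorted_ftx_idxs.
    batch_idxs = torch.arange(batch.shape[0], device=batch.device).repeat_interleave(batch.shape[1])
    batch_sorted = batch[batch_idxs, sorted_ftx_idxs, :].reshape(*batch.shape)
    return batch_sorted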

Additionally, to get descending order you need to pass descending=True to torch.argsort.
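As a rough sketch of that change, here is a descending variant that uses broadcasting instead of flattened indices. The name `sort_by_feature_desc` is made up for illustration, and the code assumes a 3-D batch of shape (B, N, F).

```python
import torch

def sort_by_feature_desc(batch: torch.Tensor, ifeature: int) -> torch.Tensor:
    # Hypothetical helper name; assumes batch has shape (B, N, F).
    assert 0 <= ifeature < batch.shape[-1]
    # (B, N) row order for each batch element, largest value first.
    order = torch.argsort(batch[..., ifeature], dim=-1, descending=True)
    # (B, 1) batch indices, broadcast against the (B, N) row order.
    batch_idxs = torch.arange(batch.shape[0], device=batch.device).unsqueeze(1)
    # Advanced indexing on the first two dims returns a (B, N, F) tensor.
    return batch[batch_idxs, order]

x = torch.randn(6, 492, 5)
out = sort_by_feature_desc(x, 4)  # sort by the 5th column, descending
```

Indexing with a (B, 1) and a (B, N) index tensor avoids the reshape at the end, since the result already has the shape of the input.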