Hi, I want to sort a 3D tensor along its last dimension. For example, take this 2 × 3 × 4 tensor:

```
a = tensor([[[1, 0, 2, 4],
             [3, 2, 2, 4],
             [0, 0, 2, 2]],

            [[1, 4, 0, 2],
             [1, 2, 4, 3],
             [2, 0, 4, 2]]])
```

Sorting each row along the last dimension should give:

```
b = tensor([[[0, 1, 2, 4],
             [2, 2, 3, 4],
             [0, 0, 2, 2]],

            [[0, 1, 2, 4],
             [1, 2, 3, 4],
             [0, 2, 2, 4]]])
```

I could write a double for-loop that sorts each row of the last dimension independently, but in my real data the first two dimensions are 2048 × 12, so that would be 24,576 separate sorts, and I'm afraid it would be too slow. Is there a smarter approach?
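For reference, the double loop described above might look like the sketch below. `sort_last_dim_loop` is a hypothetical name used only for illustration; it makes one small `torch.sort` call per row, so the Python-level overhead grows with `shape[0] * shape[1]`.

```python
import torch

def sort_last_dim_loop(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sort each row of the last dimension with an
    explicit Python double loop (the approach described in the question)."""
    out = torch.empty_like(x)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            # torch.sort on a 1-D slice returns (values, indices); keep the values
            out[i, j] = torch.sort(x[i, j]).values
    return out

a = torch.tensor([[[1, 0, 2, 4],
                   [3, 2, 2, 4],
                   [0, 0, 2, 2]],
                  [[1, 4, 0, 2],
                   [1, 2, 4, 3],
                   [2, 0, 4, 2]]])
b = sort_last_dim_loop(a)
print(b)
```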

This should work:

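`b = torch.sort(a, dim=2, stable=True).values`

`torch.sort` returns a named tuple of `(values, indices)`; `.values` is the sorted tensor.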

https://pytorch.org/docs/stable/generated/torch.sort.html#torch.sort
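A minimal sketch applying this to the example tensor. It uses `dim=-1`, which for a 3D tensor is the same as `dim=2`, and also shows what the returned `indices` mean: gathering the original tensor with them reproduces the sorted values.

```python
import torch

a = torch.tensor([[[1, 0, 2, 4],
                   [3, 2, 2, 4],
                   [0, 0, 2, 2]],
                  [[1, 4, 0, 2],
                   [1, 2, 4, 3],
                   [2, 0, 4, 2]]])

# One call sorts every row of the last dimension at once
result = torch.sort(a, dim=-1, stable=True)
b = result.values     # sorted tensor, same shape as a
idx = result.indices  # original position of each sorted value within its row

print(b)
# Row [3, 2, 2, 4]: with stable=True the two 2s keep their original order
print(idx[0, 1])  # tensor([1, 2, 0, 3])

# Gathering a with the indices along the last dim gives back the sorted values
assert torch.equal(torch.gather(a, -1, idx), b)
```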


Thank you, it works. Will it be faster than the for-loops? And does setting `stable=True` improve the speed?

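It should be considerably faster than a double loop in Python: a single `torch.sort` call sorts all rows in one batched operation instead of paying Python overhead for each of the 2048 × 12 rows. `stable=True` is not a speed option; it only guarantees that equal values keep their original relative order.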
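A rough timing sketch to check this on your own machine. It assumes a last dimension of size 4 (as in the example) and random integer data; the helper `sort_last_dim_loop` is hypothetical. It also confirms that `stable=True` does not change the sorted values, only the order of the indices of tied elements.

```python
import time
import torch

torch.manual_seed(0)
# Assumed shape: 2048 x 12 from the question, last dimension 4 as in the example
x = torch.randint(0, 5, (2048, 12, 4))

def sort_last_dim_loop(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: Python double loop, one torch.sort call per row
    out = torch.empty_like(t)
    for i in range(t.shape[0]):
        for j in range(t.shape[1]):
            out[i, j] = torch.sort(t[i, j]).values
    return out

start = time.perf_counter()
loop_out = sort_last_dim_loop(x)
loop_s = time.perf_counter() - start

start = time.perf_counter()
batched_out = torch.sort(x, dim=-1).values  # single batched call
batched_s = time.perf_counter() - start

# Sorted values are identical with or without stable=True
stable_out = torch.sort(x, dim=-1, stable=True).values

print(f"double loop: {loop_s:.4f} s, single torch.sort: {batched_s:.6f} s")
```

Absolute numbers depend on hardware; if you time on a GPU, call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels run asynchronously.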
