How to sort each row along the last dimension of a 3D tensor independently?

Hi, I want to sort a 3D tensor along its last dimension. For example, take this 2 × 3 × 4 tensor:
a = tensor([[[1, 0, 2, 4],
             [3, 2, 2, 4],
             [0, 0, 2, 2]],

            [[1, 4, 0, 2],
             [1, 2, 4, 3],
             [2, 0, 4, 2]]])
I want each row along the last dimension to be sorted independently, giving:
b = tensor([[[0, 1, 2, 4],
             [2, 2, 3, 4],
             [0, 0, 2, 2]],

            [[0, 1, 2, 4],
             [1, 2, 3, 4],
             [0, 2, 2, 4]]])
I could write a double for-loop that sorts each row separately, but in my real data the first two dimensions are 2048 × 12, which means 24,576 separate sorts. I'm afraid that would be too slow. Is there a smarter approach?

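This should work. torch.sort sorts along the given dimension and treats every row independently. It returns a (values, indices) named tuple, so take .values to get the sorted tensor:
b = torch.sort(a, dim=2, stable=True).values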

https://pytorch.org/docs/stable/generated/torch.sort.html#torch.sort
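A minimal runnable sketch of the answer above on the example tensor from the question, assuming a reasonably recent PyTorch. It uses dim=-1, which equals dim=2 for a 3D tensor but also works for other ranks. It also shows what the returned indices mean and what stable=True actually changes: it only fixes the order of tied elements in the indices, while the sorted values are the same either way.

```python
import torch

# Example tensor from the question, shape (2, 3, 4)
a = torch.tensor([[[1, 0, 2, 4],
                   [3, 2, 2, 4],
                   [0, 0, 2, 2]],
                  [[1, 4, 0, 2],
                   [1, 2, 4, 3],
                   [2, 0, 4, 2]]])

# torch.sort returns a (values, indices) named tuple. Each row along the
# last dimension (dim=-1, equivalent to dim=2 here) is sorted on its own.
values, indices = torch.sort(a, dim=-1)
b = values

# indices records where each sorted value sat in its original row, so
# gathering along the same dimension rebuilds the sorted tensor.
assert torch.equal(torch.gather(a, -1, indices), b)

# stable=True only pins down the order of tied elements in the indices:
# in [2, 1, 2, 1] the two 1s keep positions 1 then 3, the two 2s keep 0 then 2.
tie_values, tie_indices = torch.sort(torch.tensor([2, 1, 2, 1]), stable=True)

print(b)
print(tie_values.tolist(), tie_indices.tolist())
```

If only the sorted values are needed, the stable flag can be left at its default.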


Thank you, it works. Will it be faster than for-loops? And does setting stable=True improve the speed?

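It should be much faster than a double loop in Python, because a single torch.sort call handles all rows at once instead of issuing 24,576 separate Python-level sorts. As for stable=True, it is not a speed setting: it only guarantees that equal elements keep their original relative order (which affects the returned indices), so you don't need it if you only want the sorted values.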
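A rough way to check this on your own machine, as a sketch: it times the Python double loop described in the question against a single torch.sort call and confirms both give the same result. The 2048 × 12 leading shape comes from the question; the row length of 64 and the helper names sort_rows_loop, sort_rows_vectorized, and time_it are hypothetical. No timings are quoted here because they depend on hardware.

```python
import time
import torch


def sort_rows_loop(x: torch.Tensor) -> torch.Tensor:
    # Naive approach from the question: one torch.sort call per row,
    # driven by a Python double loop over the first two dimensions.
    out = torch.empty_like(x)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[i, j] = torch.sort(x[i, j]).values
    return out


def sort_rows_vectorized(x: torch.Tensor) -> torch.Tensor:
    # A single call sorts every row along the last dimension.
    return torch.sort(x, dim=-1).values


def time_it(fn, x: torch.Tensor, repeats: int = 3) -> float:
    # Best-of-N wall-clock time in seconds. This runs on CPU; on CUDA you
    # would also call torch.cuda.synchronize() before reading the clock.
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(x)
        best = min(best, time.perf_counter() - start)
    return best


if __name__ == "__main__":
    torch.manual_seed(0)
    # 2048 x 12 rows as in the question; the row length 64 is an assumption.
    x = torch.randint(0, 100, (2048, 12, 64))
    assert torch.equal(sort_rows_loop(x), sort_rows_vectorized(x))
    print(f"loop:       {time_it(sort_rows_loop, x):.4f} s")
    print(f"vectorized: {time_it(sort_rows_vectorized, x):.4f} s")
```

The loop pays Python and per-call dispatch overhead for every row, while the vectorized call pays it once, which is why the gap grows with the number of rows.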
