I’m wondering whether PyTorch exposes a radix sort API, or whether there is an efficient way to implement radix sort myself.
For example, I have a 2-D tensor:
[[1, 3, 2, 2, 1, 1], --> row 1
[6, 7, 8, 8, 8, 7], --> row 2
[9, 10, 10, 9, 9, 9]]    --> row 3
If I radix-sort the columns of this tensor (row 1 is the most significant key, ties in row 1 are broken by row 2, and ties in row 2 are broken by row 3), I want to get:
[[1, 1, 1, 2, 2, 3], --> row 1
[6, 7, 8, 8, 8, 7], --> row 2
[9, 9, 9, 9, 10, 10]]    --> row 3
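One way to get this ordering is the least-significant-digit idea behind radix sort: sort the column indices by the last row first, then re-sort by each more significant row using a stable sort, so that ties keep the order established by the less significant rows. This is a minimal sketch, assuming PyTorch 1.9 or later (for the `stable=True` argument of `torch.sort`); `lexsort_columns` is a hypothetical helper name, not a PyTorch API.

```python
import torch

def lexsort_columns(x: torch.Tensor) -> torch.Tensor:
    """Sort the columns of a 2-D tensor lexicographically, row 0 most significant.

    Hypothetical helper: one stable sort per row, from the last row to the first.
    """
    idx = torch.arange(x.size(1), device=x.device)
    # Walk from the least significant row (last) to the most significant (first).
    for row in reversed(range(x.size(0))):
        # A stable sort preserves the order set by less significant rows on ties.
        _, perm = torch.sort(x[row, idx], stable=True)
        idx = idx[perm]
    return x[:, idx]

x = torch.tensor([[1, 3, 2, 2, 1, 1],
                  [6, 7, 8, 8, 8, 7],
                  [9, 10, 10, 9, 9, 9]])
print(lexsort_columns(x))
```

This costs one `torch.sort` over the columns per row, so it works for any sortable dtype, but it is a stable sort per key rather than a literal digit-by-digit radix sort.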
Any suggestions on implementing this efficiently in PyTorch? Thanks!
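If every entry is a non-negative integer and the combined key range fits in int64, another option is to pack each column into a single integer key (a mixed-radix encoding, with row 1 as the most significant digit) and sort once with `torch.argsort`. This is a sketch under those assumptions; `pack_and_sort_columns` is a hypothetical name.

```python
import torch

def pack_and_sort_columns(x: torch.Tensor) -> torch.Tensor:
    """Sort columns lexicographically by packing each column into one int64 key.

    Assumes x holds non-negative integers and that the product of the
    per-row ranges fits in int64 (checked below).
    """
    x = x.long()
    if (x < 0).any():
        raise ValueError("packing assumes non-negative integers")
    key = torch.zeros(x.size(1), dtype=torch.long, device=x.device)
    total = 1
    for row in range(x.size(0)):
        base = int(x[row].max()) + 1  # number of distinct "digit" values in this row
        total *= base
        if total > 2**63:  # largest packed key is total - 1, which must fit in int64
            raise OverflowError("packed key does not fit in int64")
        # Shift the more significant rows up by one "digit" and append this row.
        key = key * base + x[row]
    # Equal keys mean identical columns, so an unstable argsort is fine here.
    return x[:, torch.argsort(key)]

x = torch.tensor([[1, 3, 2, 2, 1, 1],
                  [6, 7, 8, 8, 8, 7],
                  [9, 10, 10, 9, 9, 9]])
print(pack_and_sort_columns(x))
```

This needs only a single sort, but the key range grows multiplicatively with the number of rows, which is why the overflow check is there.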