I happened to run into this myself, so I have implemented a lexsort based on torch's stable sort (note: stability here is undocumented behavior).
I give no guarantees, but I think you can use torch.unique:
def torch_lexsort(a, dim=-1):
    """Torch equivalent of ``np.lexsort``: indirect sort of the columns of
    `a`, using its rows as keys with the LAST row as the primary key
    (numpy convention).

    Args:
        a: 2-D tensor whose rows are the sort keys.
        dim: must be -1 (sort along columns); kept only for interface parity.
            Transpose the input if you want a different axis.

    Returns:
        1-D LongTensor of column indices that lexicographically sort `a`.

    Raises:
        AssertionError: if ``dim != -1`` or ``a`` is not 2-D.
    """
    assert dim == -1  # Transpose if you want differently
    assert a.ndim == 2  # Not sure what is numpy behaviour with > 2 dim
    # Flip the key order so the last row becomes the primary key, to be
    # consistent with np.lexsort. torch.unique(sorted=True) orders the
    # unique columns lexicographically; `inv` then maps each original
    # column to its rank among the sorted unique columns.
    _, inv = torch.unique(a.flip(0), dim=dim, sorted=True, return_inverse=True)
    # stable=True guarantees that tied entries (duplicate columns) keep
    # their original relative order, matching np.lexsort's documented
    # stability — the original code relied on undocumented argsort behavior.
    return torch.argsort(inv, stable=True)
# Make random float vector with duplicates to test if it handles floating point well
vals = torch.rand(4)
a = vals[(torch.rand(3, 9) * 4).long()]
print(a)
ind = torch_lexsort(a)
ind_np = torch.from_numpy(np.lexsort(a.numpy()))
print("Torch ind", ind)
# BUG FIX: the original printed `ind` twice; show the numpy indices here.
print("Numpy ind", ind_np)
print("Torch result")
print(a[:, ind])
print("Numpy result")
print(a[:, ind_np])
This gives
tensor([[0.5500, 0.5346, 0.9881, 0.5500, 0.5500, 0.5346, 0.9881, 0.9881, 0.2490],
[0.5346, 0.5346, 0.5346, 0.5500, 0.2490, 0.5346, 0.5346, 0.9881, 0.2490],
[0.9881, 0.5500, 0.5500, 0.9881, 0.2490, 0.5346, 0.2490, 0.5346, 0.2490]])
Torch ind tensor([8, 4, 6, 5, 7, 1, 2, 0, 3])
Numpy ind tensor([8, 4, 6, 5, 7, 1, 2, 0, 3])
Torch result
tensor([[0.2490, 0.5500, 0.9881, 0.5346, 0.9881, 0.5346, 0.9881, 0.5500, 0.5500],
[0.2490, 0.2490, 0.5346, 0.5346, 0.9881, 0.5346, 0.5346, 0.5346, 0.5500],
[0.2490, 0.2490, 0.2490, 0.5346, 0.5346, 0.5500, 0.5500, 0.9881, 0.9881]])
Numpy result
tensor([[0.2490, 0.5500, 0.9881, 0.5346, 0.9881, 0.5346, 0.9881, 0.5500, 0.5500],
[0.2490, 0.2490, 0.5346, 0.5346, 0.9881, 0.5346, 0.5346, 0.5346, 0.5500],
        [0.2490, 0.2490, 0.2490, 0.5346, 0.5346, 0.5500, 0.5500, 0.9881, 0.9881]])