Numpy.lexsort equivalent in pytorch

I am looking for a PyTorch equivalent of numpy.lexsort. Thanks.

Example:

import numpy as np

a = np.array([[18., 25., 3.559, 0.  ],
              [18., 25., 3.559, 0.  ],
              [18., 25., 3.559, 0.  ],
              [18., 25., 3.56,  0.09],
              [15., 25., 3.447, 0.2 ],
              [14., 25., 3.445, 0.58],
              [15., 25., 3.453, 0.  ],
              [18., 26., 3.559, 0.  ],
              [18., 26., 3.558, 0.  ],
              [18., 26., 3.558, 0.  ]])

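# Sort rows by column 0, then column 1, then column 2
# (np.lexsort treats the LAST key passed as the primary key)
indices = np.lexsort((a[:, 2], a[:, 1], a[:, 0]))
b = a[indices]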

print(b)

Output:

[[14. 25. 3.445 0.58 ]
[15. 25. 3.447 0.2 ]
[15. 25. 3.453 0. ]
[18. 25. 3.559 0. ]
[18. 25. 3.559 0. ]
[18. 25. 3.559 0. ]
[18. 25. 3.56 0.09 ]
[18. 26. 3.558 0. ]
[18. 26. 3.558 0. ]
[18. 26. 3.559 0. ]]

Yes, the PyTorch documentation does not include this.

Did you figure out the solution?

I happened to run into this as well. I have implemented a lexsort based on stable sort (undocumented behavior) here.
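
For anyone who wants to see the idea without following the link, here is a minimal sketch of a lexsort built from repeated stable sorts. It is not the linked implementation; the name `lexsort_stable` is made up for illustration, and it assumes a PyTorch version whose `torch.sort` accepts `stable=True`. It sorts by the least significant key first and finishes with the primary key, so each later pass only reorders elements that differ in a more significant key.

```python
import torch

def lexsort_stable(keys: torch.Tensor) -> torch.Tensor:
    """Sketch of np.lexsort for a 2-D tensor with one key per row.

    As in NumPy, the LAST row is the primary sort key.
    (Hypothetical helper, not part of PyTorch.)
    """
    assert keys.ndim == 2
    # Start from the identity permutation.
    idx = torch.arange(keys.shape[1], device=keys.device)
    # Row 0 is the least significant key, the last row the most significant.
    for key in keys:
        # stable=True keeps the order from earlier passes among equal values.
        order = torch.sort(key[idx], stable=True).indices
        idx = idx[order]
    return idx

# The array from the original question.
a = torch.tensor([[18., 25., 3.559, 0.  ],
                  [18., 25., 3.559, 0.  ],
                  [18., 25., 3.559, 0.  ],
                  [18., 25., 3.56,  0.09],
                  [15., 25., 3.447, 0.2 ],
                  [14., 25., 3.445, 0.58],
                  [15., 25., 3.453, 0.  ],
                  [18., 26., 3.559, 0.  ],
                  [18., 26., 3.558, 0.  ],
                  [18., 26., 3.558, 0.  ]])

# Same key order as np.lexsort((a[:, 2], a[:, 1], a[:, 0])).
keys = torch.stack((a[:, 2], a[:, 1], a[:, 0]))
indices = lexsort_stable(keys)
b = a[indices]
print(indices)
print(b)
```

This does one sort per key, so it is probably slower than a single fused sort when there are many keys, but the tie order should match NumPy because every pass is stable.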

Also, I give no guarantees, but I think you can use torch.unique:

import numpy as np
import torch

def torch_lexsort(a, dim=-1):
    assert dim == -1  # Only the last dim is supported; transpose the input otherwise
    assert a.ndim == 2  # Only 2-D input; not sure what NumPy does with more than 2 dims
    # To be consistent with NumPy, flip the keys so the last row is the primary key
    a_unq, inv = torch.unique(a.flip(0), dim=dim, sorted=True, return_inverse=True)
    # inv[j] is the rank of column j among the sorted unique columns
    return torch.argsort(inv)

# Make random float vector with duplicates to test if it handles floating point well
vals = torch.rand(4)
a = vals[(torch.rand(3, 9) * 4).long()]

print(a)
ind = torch_lexsort(a)
ind_np = torch.from_numpy(np.lexsort(a.numpy()))
print("Torch ind", ind)
print("Numpy ind", ind_np)

print("Torch result")
print(a[:, ind])
print("Numpy result")
print(a[:, ind_np])

This gives

tensor([[0.5500, 0.5346, 0.9881, 0.5500, 0.5500, 0.5346, 0.9881, 0.9881, 0.2490],
        [0.5346, 0.5346, 0.5346, 0.5500, 0.2490, 0.5346, 0.5346, 0.9881, 0.2490],
        [0.9881, 0.5500, 0.5500, 0.9881, 0.2490, 0.5346, 0.2490, 0.5346, 0.2490]])
Torch ind tensor([8, 4, 6, 5, 7, 1, 2, 0, 3])
Numpy ind tensor([8, 4, 6, 5, 7, 1, 2, 0, 3])
Torch result
tensor([[0.2490, 0.5500, 0.9881, 0.5346, 0.9881, 0.5346, 0.9881, 0.5500, 0.5500],
        [0.2490, 0.2490, 0.5346, 0.5346, 0.9881, 0.5346, 0.5346, 0.5346, 0.5500],
        [0.2490, 0.2490, 0.2490, 0.5346, 0.5346, 0.5500, 0.5500, 0.9881, 0.9881]])
Numpy result
tensor([[0.2490, 0.5500, 0.9881, 0.5346, 0.9881, 0.5346, 0.9881, 0.5500, 0.5500],
        [0.2490, 0.2490, 0.5346, 0.5346, 0.9881, 0.5346, 0.5346, 0.5346, 0.5500],
        [0.2490, 0.2490, 0.2490, 0.5346, 0.5346, 0.5500, 0.5500, 0.9881, 0.9881]])
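
One caveat with the torch.unique approach: columns that are exact duplicates get the same value in `inv`, and `torch.argsort` does not promise any particular order among equal values unless a stable sort is requested. The sorted values are still the same, but the returned indices can differ from np.lexsort, which is stable. A possible variant, again only a sketch (the name `torch_lexsort_stable_ties` is made up, and it relies on torch.unique returning the unique columns in lexicographic order, which is what it appears to do but which I have not seen guaranteed):

```python
import torch

def torch_lexsort_stable_ties(a: torch.Tensor) -> torch.Tensor:
    """torch.unique-based lexsort sketch that keeps ties in input order.

    Hypothetical helper; assumes torch.unique(..., dim=-1, sorted=True)
    orders the unique columns lexicographically, first row most significant.
    """
    assert a.ndim == 2
    # flip(0) puts the last row first, so it becomes the primary key as in np.lexsort.
    _, inv = torch.unique(a.flip(0), dim=-1, sorted=True, return_inverse=True)
    # Stable sort on the ranks: duplicate columns keep their original relative order.
    return torch.sort(inv, stable=True).indices

# Two keys with duplicated columns; the last row (all zeros) is the primary key.
keys = torch.tensor([[1., 2., 1., 2.],
                     [0., 0., 0., 0.]])
print(torch_lexsort_stable_ties(keys))
```

Columns 0 and 2 are identical here, as are columns 1 and 3, so this is the kind of input where a non-stable argsort could return the duplicates in either order.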