numpy.lexsort equivalent in PyTorch

I am looking for a PyTorch equivalent of numpy.lexsort. Thanks.

Example:

import numpy as np

a = np.array([[18., 25., 3.559, 0.  ],
              [18., 25., 3.559, 0.  ],
              [18., 25., 3.559, 0.  ],
              [18., 25., 3.56,  0.09],
              [15., 25., 3.447, 0.2 ],
              [14., 25., 3.445, 0.58],
              [15., 25., 3.453, 0.  ],
              [18., 26., 3.559, 0.  ],
              [18., 26., 3.558, 0.  ],
              [18., 26., 3.558, 0.  ]])

# The last key is the primary one: sort by column 0, then column 1, then column 2
indices = np.lexsort((a[:, 2], a[:, 1], a[:, 0]))
b = a[indices]

print(b)

Output:

[[14. 25. 3.445 0.58 ]
[15. 25. 3.447 0.2 ]
[15. 25. 3.453 0. ]
[18. 25. 3.559 0. ]
[18. 25. 3.559 0. ]
[18. 25. 3.559 0. ]
[18. 25. 3.56 0.09 ]
[18. 26. 3.558 0. ]
[18. 26. 3.558 0. ]
[18. 26. 3.559 0. ]]
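To make the key order concrete: np.lexsort gives the same result as an ordinary stable sort of the rows by the tuple (column 0, column 1, column 2). A small NumPy-only sketch; the tuple-key `sorted` call is only an illustration of the ordering, not how lexsort is implemented:

```python
import numpy as np

a = np.array([[18., 25., 3.559, 0.  ],
              [18., 25., 3.559, 0.  ],
              [18., 25., 3.559, 0.  ],
              [18., 25., 3.56,  0.09],
              [15., 25., 3.447, 0.2 ],
              [14., 25., 3.445, 0.58],
              [15., 25., 3.453, 0.  ],
              [18., 26., 3.559, 0.  ],
              [18., 26., 3.558, 0.  ],
              [18., 26., 3.558, 0.  ]])

indices = np.lexsort((a[:, 2], a[:, 1], a[:, 0]))

# Python's sorted() is stable, like np.lexsort, so tied rows keep their
# original order and both index sequences should agree exactly.
indices_py = sorted(range(len(a)), key=lambda i: (a[i, 0], a[i, 1], a[i, 2]))

print(indices.tolist())  # [5, 4, 6, 0, 1, 2, 3, 8, 9, 7]
print(indices_py)        # [5, 4, 6, 0, 1, 2, 3, 8, 9, 7]
```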


Yes, the PyTorch documentation does not include this.

Did you figure out the solution?

I happened to run into this too, and I have implemented a lexsort based on stable sort (undocumented behavior) here.
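The stable-sort idea in a nutshell: sort by the least significant key first, then stable-sort by each more significant key in turn; stability preserves the earlier ordering among ties. A minimal sketch with two hypothetical keys, assuming a PyTorch version whose torch.argsort accepts stable=True:

```python
import torch

# Hypothetical keys; `primary` is the most significant one.
primary = torch.tensor([1, 0, 1, 0])
secondary = torch.tensor([3, 2, 1, 0])

# 1) Sort by the least significant key.
idx = torch.argsort(secondary, stable=True)
# 2) Stable-sort those positions by the next key; rows with equal
#    primary values stay in secondary order.
idx = idx[torch.argsort(primary[idx], stable=True)]

print(idx)  # tensor([3, 1, 2, 0]), same as np.lexsort((secondary, primary))
```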

Also, I give no guarantees, but I think you can use torch.unique:

import numpy as np
import torch

def torch_lexsort(a, dim=-1):
    assert dim == -1  # Transpose if you want differently
    assert a.ndim == 2  # Not sure what is numpy behaviour with > 2 dim
    # To be consistent with numpy, we flip the keys (sort by last row first)
    _, inv = torch.unique(a.flip(0), dim=dim, sorted=True, return_inverse=True)
    # Duplicate columns share an inverse index; a stable argsort keeps them
    # in their original order, as numpy.lexsort does
    return torch.argsort(inv, stable=True)

# Make random float vector with duplicates to test if it handles floating point well
vals = torch.rand(4)
a = vals[(torch.rand(3, 9) * 4).long()]

print(a)
ind = torch_lexsort(a)
ind_np = torch.from_numpy(np.lexsort(a.numpy()))
print("Torch ind", ind)
print("Numpy ind", ind_np)

print("Torch result")
print(a[:, ind])
print("Numpy result")
print(a[:, ind_np])

This gives

tensor([[0.5500, 0.5346, 0.9881, 0.5500, 0.5500, 0.5346, 0.9881, 0.9881, 0.2490],
        [0.5346, 0.5346, 0.5346, 0.5500, 0.2490, 0.5346, 0.5346, 0.9881, 0.2490],
        [0.9881, 0.5500, 0.5500, 0.9881, 0.2490, 0.5346, 0.2490, 0.5346, 0.2490]])
Torch ind tensor([8, 4, 6, 5, 7, 1, 2, 0, 3])
Numpy ind tensor([8, 4, 6, 5, 7, 1, 2, 0, 3])
Torch result
tensor([[0.2490, 0.5500, 0.9881, 0.5346, 0.9881, 0.5346, 0.9881, 0.5500, 0.5500],
        [0.2490, 0.2490, 0.5346, 0.5346, 0.9881, 0.5346, 0.5346, 0.5346, 0.5500],
        [0.2490, 0.2490, 0.2490, 0.5346, 0.5346, 0.5500, 0.5500, 0.9881, 0.9881]])
Numpy result
tensor([[0.2490, 0.5500, 0.9881, 0.5346, 0.9881, 0.5346, 0.9881, 0.5500, 0.5500],
        [0.2490, 0.2490, 0.5346, 0.5346, 0.9881, 0.5346, 0.5346, 0.5346, 0.5500],
        [0.2490, 0.2490, 0.2490, 0.5346, 0.5346, 0.5500, 0.5500, 0.9881, 0.9881]])
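Why the tie-breaking of the final argsort matters in the torch.unique approach: identical columns receive the same inverse index, so their relative order depends only on how argsort breaks ties. numpy.lexsort keeps ties in their original order, which a stable argsort reproduces. A small sketch with a hypothetical input whose columns 0 and 2 are identical (assuming torch.argsort supports stable=True):

```python
import numpy as np
import torch

# Hypothetical input: columns 0 and 2 are duplicates.
a = torch.tensor([[1., 0., 1., 0.],
                  [2., 2., 2., 1.]])

# Flip so the last row is compared first, matching numpy.lexsort.
_, inv = torch.unique(a.flip(0), dim=-1, sorted=True, return_inverse=True)
print(inv)  # tensor([2, 1, 2, 0]): columns 0 and 2 share an index

# A stable argsort keeps tied columns in their original order.
ind = torch.argsort(inv, stable=True)
ind_np = np.lexsort(a.numpy())

print(ind)     # tensor([3, 1, 0, 2])
print(ind_np)  # [3 1 0 2]
```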

Now that PyTorch's argsort supports stable=True, I rewrote the linked solution by @wouter more concisely:

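import torch

def lexsort(keys, dim=-1):
    # keys has shape (K, ...); keys[-1] is the primary key, as in numpy.lexsort
    if keys.ndim < 2:
        raise ValueError(f"keys must be at least 2 dimensional, but {keys.ndim=}.")
    if len(keys) == 0:
        raise ValueError(f"Must have at least 1 key, but {len(keys)=}.")

    # Sort by the least significant key first ...
    idx = keys[0].argsort(dim=dim, stable=True)
    # ... then stable-sort by each more significant key in turn
    for k in keys[1:]:
        idx = idx.gather(dim, k.gather(dim, idx).argsort(dim=dim, stable=True))

    return idx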
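For reference, applying this lexsort to the array from the original question might look like the sketch below. The keys are stacked in the same order as the np.lexsort call, so column 0 is the primary key. The function body is copied from above, and the float64 dtype is my own choice so the values match NumPy's exactly:

```python
import numpy as np
import torch

def lexsort(keys, dim=-1):
    # Copy of the function above; keys[-1] is the primary key.
    if keys.ndim < 2:
        raise ValueError(f"keys must be at least 2 dimensional, but {keys.ndim=}.")
    if len(keys) == 0:
        raise ValueError(f"Must have at least 1 key, but {len(keys)=}.")
    idx = keys[0].argsort(dim=dim, stable=True)
    for k in keys[1:]:
        idx = idx.gather(dim, k.gather(dim, idx).argsort(dim=dim, stable=True))
    return idx

a = torch.tensor([[18., 25., 3.559, 0.  ],
                  [18., 25., 3.559, 0.  ],
                  [18., 25., 3.559, 0.  ],
                  [18., 25., 3.56,  0.09],
                  [15., 25., 3.447, 0.2 ],
                  [14., 25., 3.445, 0.58],
                  [15., 25., 3.453, 0.  ],
                  [18., 26., 3.559, 0.  ],
                  [18., 26., 3.558, 0.  ],
                  [18., 26., 3.558, 0.  ]], dtype=torch.float64)

# Same key order as np.lexsort((a[:, 2], a[:, 1], a[:, 0])); shape (3, 10)
keys = torch.stack((a[:, 2], a[:, 1], a[:, 0]))
indices = lexsort(keys)
b = a[indices]

print(indices)  # tensor([5, 4, 6, 0, 1, 2, 3, 8, 9, 7])
print(indices.tolist() == np.lexsort(keys.numpy()).tolist())  # True
```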

Compared to the torch.unique-based solution by @wouter, this version can handle an arbitrary dim and also handles tensors containing NaNs. There is a torch.unique bug that makes it fail on inputs with NaNs when dim is not None; see pytorch/pytorch Issue #76571 ("torch.unique() nondeterministic behavior on nan inputs (on GPU)") and Issue #95583 ("torch.unique output incorrect when tensor contains NaNs") on GitHub. Until this bug is resolved, I don't see any way of getting rid of the for-loop while maintaining correctness on inputs with NaNs.
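A quick CPU sketch of both properties, batched keys sorted along dim=-1 and NaNs in a key, using a hypothetical tensor of shape (2 keys, 2 batch rows, 4 elements). It relies on PyTorch's sort placing NaN after all other values, which is also where NumPy's sorts put it:

```python
import numpy as np
import torch

def lexsort(keys, dim=-1):
    # Copy of the function above; keys[-1] is the primary key.
    if keys.ndim < 2:
        raise ValueError(f"keys must be at least 2 dimensional, but {keys.ndim=}.")
    if len(keys) == 0:
        raise ValueError(f"Must have at least 1 key, but {len(keys)=}.")
    idx = keys[0].argsort(dim=dim, stable=True)
    for k in keys[1:]:
        idx = idx.gather(dim, k.gather(dim, idx).argsort(dim=dim, stable=True))
    return idx

nan = float("nan")
keys = torch.tensor([
    # key 0 (least significant), one row per batch entry
    [[0., 1., 0., 1.],
     [2., 2., 1., 1.]],
    # key 1 (primary); batch row 0 contains NaNs
    [[nan, 1., nan, 0.],
     [0., 0., 0., 0.]],
])

# Each batch row is sorted independently along the last dimension.
idx = lexsort(keys, dim=-1)
print(idx)  # tensor([[3, 1, 0, 2], [2, 3, 0, 1]])

# numpy.lexsort accepts the same (K, ...) layout with axis=-1.
print(idx.tolist() == np.lexsort(keys.numpy(), axis=-1).tolist())  # True
```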