How to sort tensor by given order?

Hi,

I’m looking for a function to sort the values of a 2D tensor by a given order.

For example, A is the 2D tensor [[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]], and I want to reorder A according to an order tensor B = [[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]], whose entries are column positions in A. The reordered A should then be [[2, 2, 1, 0, 1], [1, 2, 0, 2, 2]]. A row of B can contain repeated positions, and each row of A is reordered by the corresponding row of B.
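Put differently, the rule being asked for is out[i][j] = A[i][B[i][j]]. A plain-Python sketch of that rule, using nested lists instead of tensors (the helper name `reorder_rows` is hypothetical, just for illustration):

```python
def reorder_rows(a, b):
    # Hypothetical helper: out[i][j] = a[i][b[i][j]].
    # Each row of b holds positions (not values) into the matching row of a,
    # so a repeated position simply copies the same element more than once.
    return [[row_a[k] for k in row_b] for row_a, row_b in zip(a, b)]


A = [[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]]
B = [[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]]
print(reorder_rows(A, B))  # [[2, 2, 1, 0, 1], [1, 2, 0, 2, 2]]
```

This makes it clear that the operation is really a per-row lookup (a "gather") rather than a sort in the comparison sense.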

Is there any function I can directly apply to get the sorted A?

Any suggestion is appreciated!

Hello,

As far as I know, there is no such function for reordering one 2D tensor by another 2D tensor. However, a simple for loop with indexing achieves the desired result:

    import torch

    a = torch.tensor([[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]], dtype=torch.int32)
    # the index tensor (b) must hold integer indices (torch.long);
    # a bool/byte tensor would be treated as a mask instead
    b = torch.tensor([[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]], dtype=torch.long)
    for i in range(b.shape[0]):
        # reorder the ith row of a using the indices in b[i]
        c = a[i][b[i]]
        print(c)
    # tensor([2, 2, 1, 0, 1], dtype=torch.int32)
    # tensor([1, 2, 0, 2, 2], dtype=torch.int32)

Don’t know if this helps!
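As an alternative to the loop above, `torch.gather` appears to cover exactly this case: along `dim=1` it computes out[i][j] = a[i][b[i][j]], and repeated indices are allowed. A minimal sketch with the tensors from the question, assuming `b` is an int64 tensor (which `gather` requires for its index):

```python
import torch

a = torch.tensor([[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]])
b = torch.tensor([[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]])  # int64 by default

# Along dim=1, gather computes out[i][j] = a[i][b[i][j]]:
# each row of b picks (possibly repeated) positions from the same row of a.
out = torch.gather(a, 1, b)
print(out)
# tensor([[2, 2, 1, 0, 1],
#         [1, 2, 0, 2, 2]])
```

The method form `a.gather(1, b)` is equivalent.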


A vectorized version, without the Python loop, can be implemented as follows:

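    import torch

    def smart_sort(x, permutation):
        d1, d2 = x.size()
        ret = x[
            # row index for every element: each row id repeated d2 times
            torch.arange(d1).unsqueeze(1).repeat((1, d2)).flatten(),
            # column index for every element, taken from the order tensor
            permutation.flatten()
        ].view(d1, d2)
        return ret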

And it sorts as you want:

    a = torch.tensor([[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]])
    b = torch.tensor([[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]])
    smart_sort(a, b)
    # output: tensor([[2, 2, 1, 0, 1],
    #                 [1, 2, 0, 2, 2]])
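The `repeat`/`flatten`/`view` steps in `smart_sort` can probably be dropped, since advanced indexing broadcasts its index tensors: a `(d1, 1)` row index broadcasts against the `(d1, d2)` order tensor. A sketch of that variant (the name `smart_sort_broadcast` is hypothetical):

```python
import torch


def smart_sort_broadcast(x, permutation):
    # Hypothetical variant of smart_sort: the (d1, 1) row index broadcasts
    # against the (d1, d2) column index, so no repeat/flatten/view is needed.
    rows = torch.arange(x.size(0)).unsqueeze(1)  # shape (d1, 1)
    return x[rows, permutation]  # out[i][j] = x[i][permutation[i][j]]


a = torch.tensor([[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]])
b = torch.tensor([[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]])
print(smart_sort_broadcast(a, b))
# tensor([[2, 2, 1, 0, 1],
#         [1, 2, 0, 2, 2]])
```

Both versions build the same (row, column) index pairs; the broadcast form just avoids materializing the repeated row ids.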

Thanks for the help!

Thank you for your advice!

You can also try this, which sorts each row (or column) of a tensor by its own values:

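    import torch

    def sort_tensor(X, dim=-1):
        d1, d2 = X.size()
        # ind[i][j] is the index of the element that belongs at position (i, j)
        _, ind = torch.sort(X, dim=dim)
        if dim in (-1, 1):
            # sort within each row: pair every row number with its column order
            return X[torch.arange(d1).unsqueeze(1).repeat((1, d2)), ind]
        elif dim in (0, -2):
            # sort within each column: ind gives row numbers, arange(d2) broadcasts as columns
            return X[ind, torch.arange(d2)]
        else:
            raise IndexError("wrong dim parameter")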

Example:

    t = torch.tensor([[23, 45, 66, 100],
                      [90, 56, 8, 132],
                      [34, 90, 88, 200],
                      [250, 15, 90, 5]])

    sort_tensor(t, dim=-1)
    # output:
    # tensor([[ 23,  45,  66, 100],
    #         [  8,  56,  90, 132],
    #         [ 34,  88,  90, 200],
    #         [  5,  15,  90, 250]])

    sort_tensor(t, dim=0)
    # output:
    # tensor([[ 23,  15,   8,   5],
    #         [ 34,  45,  66, 100],
    #         [ 90,  56,  88, 132],
    #         [250,  90,  90, 200]])
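Note that `sort_tensor` returns the same tensor as the values from `torch.sort`. The indices `torch.sort` returns are themselves an "order" tensor of the kind asked about in the question, so gathering `t` by them should reproduce the sorted values. A small check of that relationship, reusing `t` from the example above:

```python
import torch

t = torch.tensor([[23, 45, 66, 100],
                  [90, 56, 8, 132],
                  [34, 90, 88, 200],
                  [250, 15, 90, 5]])

# torch.sort returns both the sorted values and the order (indices) that
# produces them; gathering t by that order reproduces the sorted values.
values, order = torch.sort(t, dim=1)
assert torch.equal(torch.gather(t, 1, order), values)

# The same holds column-wise with dim=0.
values0, order0 = torch.sort(t, dim=0)
assert torch.equal(torch.gather(t, 0, order0), values0)

print(order)
# tensor([[0, 1, 2, 3],
#         [2, 1, 0, 3],
#         [0, 2, 1, 3],
#         [3, 1, 2, 0]])
```

So sorting by a given order (the question) and sorting by value (this reply) share the same indexing step; they differ only in where the order tensor comes from.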
                  
1 Like