# How to sort tensor by given order?

Hi,

I’m looking for a function to sort the values of a 2D tensor by a given order.

For example, let A be the 2D tensor ([1, 0, 2, 2, 1], [0, 2, 1, 2, 0]), and suppose I want to sort A following the order tensor B = ([3, 2, 4, 1, 0], [2, 1, 4, 3, 3]). Then the sorted A should be ([2, 2, 1, 0, 1], [1, 2, 0, 2, 2]). A row of B can contain repeated indices, and each row of A is reordered by the indices in the corresponding row of B.
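In other words, each output element is `sorted_A[i][j] = A[i][B[i][j]]`: row `i` of B lists which columns of row `i` of A to copy, so a repeated index simply copies the same element twice. A minimal plain-Python sketch of that rule, using nested lists instead of tensors (the helper name `reorder_rows` is hypothetical, for illustration only):

```python
def reorder_rows(a, b):
    # For each row i, pick a[i][k] for every index k listed in b[i] (repeats allowed)
    return [[row_a[k] for k in row_b] for row_a, row_b in zip(a, b)]

A = [[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]]
B = [[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]]
print(reorder_rows(A, B))  # [[2, 2, 1, 0, 1], [1, 2, 0, 2, 2]]
```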

Is there any function I can directly apply to get the sorted A?

Any suggestion is appreciated!

Hello,

As far as I know, there is no built-in function that sorts one 2D tensor by the order given in another 2D tensor. However, a simple `for` loop with indexing achieves the desired result:

```python
import torch

a = torch.tensor([[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]], dtype=torch.int32)
# the index tensor (b) must be an integer (long) tensor; byte/bool tensors would act as masks instead
b = torch.tensor([[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]], dtype=torch.long)
for i in range(b.shape[0]):
    # reorder the ith row of a with the indexes in b[i]
    c = a[i][b[i]]
    print(c)
# tensor([2, 2, 1, 0, 1], dtype=torch.int32)
# tensor([1, 2, 0, 2, 2], dtype=torch.int32)
```
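If you need the result as a single 2D tensor rather than printed rows, the same loop can collect each reordered row and combine them with `torch.stack`. A minimal sketch, with `a` and `b` redefined as above so the snippet runs on its own:

```python
import torch

a = torch.tensor([[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]], dtype=torch.int32)
b = torch.tensor([[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]], dtype=torch.long)

# Reorder each row of a with the matching row of b, then stack the rows back into 2D
c = torch.stack([a[i][b[i]] for i in range(b.shape[0])])
print(c)
# tensor([[2, 2, 1, 0, 1],
#         [1, 2, 0, 2, 2]], dtype=torch.int32)
```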

Don’t know if this helps!


A tensorized (loop-free) version can be implemented as follows:

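```python
import torch

def smart_sort(x, permutation):
    d1, d2 = x.size()
    # Pair a flattened row index with the flattened column indices from permutation,
    # then restore the original 2D shape
    ret = x[
        torch.arange(d1).unsqueeze(1).repeat((1, d2)).flatten(),
        permutation.flatten()
    ].view(d1, d2)
    return ret
```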
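The first index passed to `x[...]` is a flattened row index that repeats each row number `d2` times, so that it lines up element by element with `permutation.flatten()`. For the 2×5 example in the question it looks like this:

```python
import torch

d1, d2 = 2, 5
# Row numbers 0..d1-1, each repeated d2 times, then flattened to 1D
rows = torch.arange(d1).unsqueeze(1).repeat((1, d2)).flatten()
print(rows)  # tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
```

Pairing `rows[k]` with `permutation.flatten()[k]` selects `x[i][permutation[i][j]]` for every position, and `.view(d1, d2)` restores the 2D shape.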

It reorders the rows as requested:

```python
a = torch.tensor([[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]])
b = torch.tensor([[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]])
smart_sort(a, b)
# output: tensor([[2, 2, 1, 0, 1],
#                 [1, 2, 0, 2, 2]])
```
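The same row-wise reordering can also be written without building the repeated row index. Advanced indexing broadcasts a `(d1, 1)` column of row numbers against the `(d1, d2)` permutation, and `torch.gather` along `dim=1` computes `out[i][j] = x[i][index[i][j]]` directly. A minimal sketch of both alternatives (these are not part of the answer above, just equivalent formulations):

```python
import torch

a = torch.tensor([[1, 0, 2, 2, 1], [0, 2, 1, 2, 0]])
b = torch.tensor([[3, 2, 4, 1, 0], [2, 1, 4, 3, 3]])

# Broadcast row numbers of shape (2, 1) against b of shape (2, 5)
by_indexing = a[torch.arange(a.size(0)).unsqueeze(1), b]

# gather along dim=1: out[i][j] = a[i][b[i][j]]; the index must be an int64 tensor
by_gather = torch.gather(a, 1, b)

print(by_gather)
# tensor([[2, 2, 1, 0, 1],
#         [1, 2, 0, 2, 2]])
```

`torch.gather` keeps the intent explicit and works for any `dim`, while the broadcasting form stays closest to the indexing used in `smart_sort`.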

Thanks for the help!

You can also try this:

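```python
import torch

def sort_tensor(X, dim=-1):
    d1, d2 = X.size()
    # ind[i][j] is the original position of the element that belongs at (i, j) after sorting
    _, ind = torch.sort(X, dim=dim)
    if dim in (-1, 1):
        # sort along each row: pair every row number with that row's sort indices
        return X[torch.arange(d1).unsqueeze(1).repeat((1, d2)), ind]
    elif dim in (0, -2):
        # sort along each column: pair the sort indices with every column number
        return X[ind, torch.arange(d2)]
    else:
        raise IndexError("wrong dim parameter")
```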

Example usage:

```python
t = torch.tensor([[23, 45, 66, 100],
                  [90, 56, 8, 132],
                  [34, 90, 88, 200],
                  [250, 15, 90, 5]])

sort_tensor(t, dim=-1)
# output:
# tensor([[ 23,  45,  66, 100],
#         [  8,  56,  90, 132],
#         [ 34,  88,  90, 200],
#         [  5,  15,  90, 250]])

sort_tensor(t, dim=0)
# output:
# tensor([[ 23,  15,   8,   5],
#         [ 34,  45,  66, 100],
#         [ 90,  56,  88, 132],
#         [250,  90,  90, 200]])
```
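Because `ind` comes from sorting `X` itself, indexing `X` with it reproduces the sorted values that `torch.sort` already returns as its first output. For any `dim`, `torch.gather(X, dim, ind)` performs the same lookup without special-casing each dimension. A minimal sketch checking this on the tensor above:

```python
import torch

t = torch.tensor([[23, 45, 66, 100],
                  [90, 56, 8, 132],
                  [34, 90, 88, 200],
                  [250, 15, 90, 5]])

for dim in (0, 1):
    values, ind = torch.sort(t, dim=dim)
    # Gathering t with its own sort indices gives back exactly the sorted values
    assert torch.equal(torch.gather(t, dim, ind), values)
```

The same `torch.gather(X, dim, ind)` call also applies when `ind` comes from sorting a different tensor of the same shape, which is the case asked about in the question.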