# Sort index according to label permutation

Hello!

I have a batch tensor called `label_perm` that represents a permutation of labels.
I have another batch tensor called `node_label` that gives the label of each node.
I want a batch tensor of node indices, called `res`, that follows the order of `label_perm`:

For example:

label_perm = [[0,1,2]] # label 0 comes first, then label 1, and finally label 2
node_label = [[1,2,1,0,1,1,2]] # the label of each node
res = [[3,0,2,4,5,1,6]]
The first label in `label_perm` is 0, and only node 3 has label 0. The second label is 1, so nodes 0, 2, 4 and 5 come next (in ascending index order), and so on…
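To pin down the expected semantics, here is a slow, loop-based reference sketch in plain Python (not vectorized, and not taken from either answer). It is only meant for checking a vectorized version; the helper name `order_nodes_by_label_perm` is hypothetical.

```python
def order_nodes_by_label_perm(label_perm, node_label):
    """Reference ordering for ONE batch element (hypothetical helper).

    Walk through the labels in label_perm order and, for each label,
    append the indices of the nodes carrying it in ascending index order.
    """
    res = []
    for label in label_perm:
        res.extend(i for i, l in enumerate(node_label) if l == label)
    return res


label_perm = [[0, 1, 2]]
node_label = [[1, 2, 1, 0, 1, 1, 2]]
# Apply the per-element reference to every batch element.
res = [order_nodes_by_label_perm(p, n) for p, n in zip(label_perm, node_label)]
print(res)  # [[3, 0, 2, 4, 5, 1, 6]]
```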
Do you know how to do this in torch with an arbitrary batch size?

Thank you very much!

This isn't a clean solution, but it might at least help you get started.

Defining Tensors

```python
import torch

label_perm = torch.tensor([[0, 1, 2]])
node_label = torch.tensor([[1, 2, 1, 0, 1, 1, 2]])
```

Assume `node_label` has shape `(1, m)` and `label_perm` has shape `(1, n)`. Reshaping `label_perm` to `(n, 1)` (using `reshape`) lets broadcasting compare every label with every node, producing a boolean tensor `value_present` of shape `(n, m)` where `value_present[i, j]` is `True` if `label_perm[0, i] == node_label[0, j]`.

```python
value_present = (node_label == label_perm.reshape(-1, 1))
"""
value_present =
tensor([[False, False, False,  True, False, False, False],
        [ True, False,  True, False,  True,  True, False],
        [False,  True, False, False, False, False,  True]])
"""
```

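Now get the indices of all `True` values in `value_present` with `nonzero()` and store them in `indices`. `nonzero()` returns the coordinates in row-major order, so the matches are grouped by position in `label_perm` and, within each label, sorted by node index. The second column of `indices` is therefore the desired result.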

```python
indices = value_present.nonzero()
result = indices[:, 1]
"""
indices =
tensor([[0, 3],
        [1, 0],
        [1, 2],
        [1, 4],
        [1, 5],
        [2, 1],
        [2, 6]])
result = tensor([3, 0, 2, 4, 5, 1, 6])
"""
```

You can adapt this to arbitrary batch sizes by adding a batch dimension to the comparison.
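One way to do that, sketched under the assumption that each `label_perm[b]` is a permutation containing every label that appears in `node_label[b]`, is to compare `(B, n, 1)` against `(B, 1, m)` so broadcasting builds a `(B, n, m)` match tensor. The function name `order_nodes_batched` is hypothetical.

```python
import torch


def order_nodes_batched(label_perm: torch.Tensor, node_label: torch.Tensor) -> torch.Tensor:
    """label_perm: (B, n), node_label: (B, m) -> (B, m) node indices."""
    # match[b, i, j] is True when label_perm[b, i] == node_label[b, j]
    match = label_perm.unsqueeze(-1) == node_label.unsqueeze(1)  # (B, n, m)
    # nonzero() lists (b, i, j) triples in row-major order: grouped by batch,
    # then by position i in label_perm, then by ascending node index j.
    idx = match.nonzero()
    # Assumption: every node's label occurs exactly once in label_perm, so each
    # batch element contributes exactly m matches and the result reshapes to (B, m).
    return idx[:, -1].view(label_perm.size(0), -1)


label_perm = torch.tensor([[0, 1, 2], [0, 1, 2]])
node_label = torch.tensor([[1, 2, 1, 0, 1, 1, 2], [1, 2, 1, 1, 1, 1, 2]])
result = order_nodes_batched(label_perm, node_label)
# result.tolist() == [[3, 0, 2, 4, 5, 1, 6], [0, 2, 3, 4, 5, 1, 6]]
```

If a node's label were missing from `label_perm`, that batch element would yield fewer than `m` matches and the final `view` would fail or misalign rows.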


Thank you very much for your answer, it really does the job!

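I found another solution myself using one-hot encoding and `bmm`. Here is an example:

```python
import torch
import torch.nn.functional as F

label_perm = torch.tensor([[0, 1, 2], [0, 1, 2]])
node_label = torch.tensor([[1, 2, 1, 0, 1, 1, 2], [1, 2, 1, 1, 1, 1, 2]])
res = torch.tensor([[3, 0, 2, 4, 5, 1, 6], [0, 2, 3, 4, 5, 1, 6]])  # expected output

# Cast to float so bmm also runs on GPU, where integer matmul may not be supported.
a = F.one_hot(label_perm, num_classes=3).float()  # (B, n, C)
b = F.one_hot(node_label, num_classes=3).float()  # (B, m, C)
# z[b, i, j] == 1 when label_perm[b, i] == node_label[b, j]
z = torch.bmm(a, b.transpose(-1, -2))             # (B, n, m)
k = torch.argwhere(z > 0)                         # (b, i, j) triples in row-major order
perm = k[:, -1].view(label_perm.size(0), -1)      # node indices, one row per batch element
assert torch.equal(perm, res)
```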
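A different technique, not from either post above, avoids building the `(n, m)` match matrix altogether: invert the permutation with `argsort`, map each node's label to its rank with `gather`, then stable-sort the ranks. This is a sketch assuming each `label_perm[b]` is a permutation of `0..n-1` and `node_label` holds int64 labels in that range; the function name `order_nodes_argsort` is hypothetical.

```python
import torch


def order_nodes_argsort(label_perm: torch.Tensor, node_label: torch.Tensor) -> torch.Tensor:
    """label_perm: (B, n), node_label: (B, m) int64 -> (B, m) node indices."""
    # For a permutation p, argsort(p) is its inverse:
    # rank[b, label] = position of `label` in label_perm[b].
    rank = torch.argsort(label_perm, dim=-1)               # (B, n)
    # Replace each node's label by that label's rank.
    node_rank = torch.gather(rank, 1, node_label)          # (B, m)
    # A stable sort keeps nodes with equal rank in ascending index order.
    return torch.sort(node_rank, dim=-1, stable=True).indices


label_perm = torch.tensor([[0, 1, 2], [2, 0, 1]])
node_label = torch.tensor([[1, 2, 1, 0, 1, 1, 2], [1, 2, 1, 0, 1, 1, 2]])
perm = order_nodes_argsort(label_perm, node_label)
# perm.tolist() == [[3, 0, 2, 4, 5, 1, 6], [1, 6, 3, 0, 2, 4, 5]]
```

Because the output always has shape `(B, m)`, no reshape is needed, and the cost is a sort over `m` nodes rather than an `n × m` comparison per batch element.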