Sort index according to label permutation

Hello!

I have a batch tensor called label_perm that represents a permutation of the labels.
I have another batch tensor called node_label that gives the label of each node.
I want to build a batch tensor of node indices called res whose order follows label_perm:

For example:

label_perm = [[0,1,2]] # label 0 comes first, then label 1, and finally label 2
node_label = [[1,2,1,0,1,1,2]] # the label of each node
res = [[3,0,2,4,5,1,6]]
The first label in label_perm is 0, and only node 3 has label 0. The second label is 1, and nodes 0, 2, 4 and 5 have label 1, and so on. Nodes that share a label keep their original (ascending) index order.
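To pin down the expected behaviour (and to have something to check faster versions against), here is a plain loop-based reference. It is only a sketch, assuming every node's label appears in that sample's label_perm so all output rows have the same length; the helper name sort_nodes_by_label_order is hypothetical.

```python
import torch

def sort_nodes_by_label_order(label_perm, node_label):
    """Loop-based reference (hypothetical helper name).

    For each sample, emit node indices label by label in the order given by
    label_perm; nodes sharing a label keep their ascending index order.
    Assumes every node's label appears in that sample's label_perm, so all
    rows have the same length.
    """
    rows = []
    for perm, labels in zip(label_perm.tolist(), node_label.tolist()):
        row = []
        for lab in perm:
            # all nodes carrying this label, in ascending index order
            row.extend(j for j, l in enumerate(labels) if l == lab)
        rows.append(row)
    return torch.tensor(rows)

label_perm = torch.tensor([[0, 1, 2]])
node_label = torch.tensor([[1, 2, 1, 0, 1, 1, 2]])
res = sort_nodes_by_label_order(label_perm, node_label)
print(res)  # tensor([[3, 0, 2, 4, 5, 1, 6]])
```

It is slow for large batches, but it makes the tie-breaking rule (ascending node index within a label) explicit.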
Do you know how to do this in torch with an arbitrary batch size?

Thank you very much! 🙂

This isn’t a clean solution, but it might at least help you get started.

Defining Tensors

import torch

label_perm = torch.tensor([[0,1,2]])
node_label = torch.tensor([[1,2,1,0,1,1,2]])

Broadcasting

Assume that node_label has shape (1, m) and label_perm has shape (1, n). Reshaping label_perm to (n, 1) (using reshape) lets the comparison broadcast: comparing a (1, m) tensor with an (n, 1) tensor produces a boolean tensor value_present of shape (n, m), where value_present[i, j] is True if label_perm[0, i] == node_label[0, j].

value_present = (node_label == label_perm.reshape(-1, 1))
"""
value_present =
 tensor([[False, False, False,  True, False, False, False],
        [ True, False,  True, False,  True,  True, False],
        [False,  True, False, False, False, False,  True]])
"""

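Now get the indices of all True values in value_present with nonzero() and store them in indices. Each row of indices is a pair (label position, node index), and nonzero() lists them in row-major order: grouped by label position in label_perm order and, within each label, by ascending node index. The second column of indices is therefore the final result.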

indices = value_present.nonzero()
result = indices[:, 1]
"""
indices = 
  tensor([[0, 3],
          [1, 0],
          [1, 2],
          [1, 4],
          [1, 5],
          [2, 1],
          [2, 6]])
result = tensor([3, 0, 2, 4, 5, 1, 6])
"""

This handles a single sample; you can modify it to work for arbitrary batch sizes.
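One way to do that modification, sketched under the assumption that every node's label occurs exactly once in its sample's label_perm: add a batch dimension with unsqueeze instead of reshape(-1, 1), so each sample is compared only against its own permutation.

```python
import torch

label_perm = torch.tensor([[0, 1, 2], [0, 1, 2]])          # (B, n)
node_label = torch.tensor([[1, 2, 1, 0, 1, 1, 2],
                           [1, 2, 1, 1, 1, 1, 2]])          # (B, m)

# (B, 1, m) == (B, n, 1) broadcasts to (B, n, m):
# value_present[b, i, j] is True when node j of sample b has label label_perm[b, i]
value_present = node_label.unsqueeze(1) == label_perm.unsqueeze(2)

# nonzero() returns (batch, label position, node) triples in row-major order,
# so within each sample the nodes come grouped by label position.
indices = value_present.nonzero()

# Assumes each node matches exactly one label, so every sample contributes
# exactly m rows and the node column can be reshaped back to (B, m).
result = indices[:, 2].reshape(node_label.shape)
print(result)
# expected values: [[3, 0, 2, 4, 5, 1, 6], [0, 2, 3, 4, 5, 1, 6]]
```

If some node's label is missing from label_perm, the samples contribute different numbers of rows and the final reshape fails, so that assumption matters.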


Thank you very much for your answer, it really does the job!

I also found another solution on my own using one-hot encoding (F.one_hot) and torch.bmm 🙂

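import torch
import torch.nn.functional as F

label_perm = torch.tensor([[0,1,2],[0,1,2]])
node_label = torch.tensor([[1,2,1,0,1,1,2],[1,2,1,1,1,1,2]])
res = torch.tensor([[3,0,2,4,5,1,6],[0,2,3,4,5,1,6]])  # expected output
a = F.one_hot(label_perm, num_classes=3)               # (B, n, num_classes)
b = F.one_hot(node_label, num_classes=3)               # (B, m, num_classes)
# z[bi, i, j] = 1 when node j has label label_perm[bi, i]
# (cast to float because integer bmm is not supported on every device)
z = torch.bmm(a.float(), b.transpose(-1, -2).float())  # (B, n, m)
k = torch.argwhere(z > 0)                              # rows of (batch, label position, node)
perm = k[:, -1].view(label_perm.shape[0], -1)          # works for any batch size, not just 2
assert torch.equal(perm, res)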

Above is an example 🙂
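Both answers build a per-sample n × m match table. A different technique, not used in the thread, avoids that table with a stable sort: look up the position of each node's label in label_perm, then stably sort the nodes by that position. This is a sketch assuming every row of label_perm is a permutation of 0..n-1 (so node labels are valid indices into it); sort_nodes_by_perm is a hypothetical name.

```python
import torch

def sort_nodes_by_perm(label_perm, node_label):
    # Hypothetical helper. Assumes each row of label_perm is a permutation of 0..n-1.
    # argsort of a permutation is its inverse: rank[b, l] = position of label l in label_perm[b]
    rank = torch.argsort(label_perm, dim=1)
    # node_rank[b, j] = rank[b, node_label[b, j]], the position of node j's label
    node_rank = torch.gather(rank, 1, node_label)
    # a stable sort keeps nodes that share a label in ascending index order
    return torch.sort(node_rank, dim=1, stable=True).indices

label_perm = torch.tensor([[2, 0, 1], [0, 1, 2]])
node_label = torch.tensor([[1, 2, 1, 0, 1, 1, 2],
                           [1, 2, 1, 1, 1, 1, 2]])
print(sort_nodes_by_perm(label_perm, node_label))
# expected values: [[1, 6, 3, 0, 2, 4, 5], [0, 2, 3, 4, 5, 1, 6]]
```

The match-table approaches need O(n·m) memory per sample, while this version needs only O(n + m) extra memory and O(m log m) time per row. It also tolerates any batch size without a reshape, because torch.sort already returns a (B, m) tensor.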
