Faster way to index a tensor

Hi,

For some reason, I need to perform the operation below.


import torch

labels = torch.tensor([[1, 0, 0, 1], [1, 0, 1, 1]])
counts = torch.tensor([[2, 2], [1, 3]])
F = torch.zeros(2, 2, 4)
for i in range(2):
    for j in range(2):
        # Write counts[i, j] at every position where labels[i] equals j
        F[i, j, labels[i] == j] = counts[i, j]
print(F)

Out:

tensor([[[0., 2., 2., 0.],
         [2., 0., 0., 2.]],

        [[0., 1., 0., 0.],
         [3., 0., 3., 3.]]])

The double for loop makes this operation slow. Is there a torch function that can make it faster?
Thanks!

I solved this problem by writing a custom CUDA extension. :wink:

Hi there, I ran into the same problem. Do you mind sharing your solution?
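
In case it helps anyone landing here: this is not the CUDA extension mentioned above (that code was never posted), just a plain PyTorch sketch that seems to give the same result without the two loops. It assumes the labels are integers in the range 0 to counts.shape[1] - 1 and that a float32 result is wanted, as in the original example. The names num_classes, classes and mask are my own. The idea is to build a boolean mask with mask[i, j, k] = (labels[i, k] == j) by broadcasting, then scale it by counts[i, j].

```python
import torch

labels = torch.tensor([[1, 0, 0, 1], [1, 0, 1, 1]])  # shape (B, N)
counts = torch.tensor([[2, 2], [1, 3]])              # shape (B, C)
num_classes = counts.shape[1]                        # assumed: labels lie in [0, C)

# classes has shape (1, C, 1) so it broadcasts against labels of shape (B, 1, N)
classes = torch.arange(num_classes).view(1, -1, 1)

# mask[i, j, k] is True exactly where labels[i, k] == j; shape (B, C, N)
mask = labels.unsqueeze(1) == classes

# Multiply by counts[i, j], broadcast along the last dimension
F = mask.to(torch.float32) * counts.unsqueeze(-1)
print(F)
```

On the example from the question this should print the same tensor as the loop version. The mask temporarily needs B * C * N elements, so for a very large number of classes the memory cost may matter more than the loop overhead.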