Hi,
For some reason, I need to perform the operation below.
```python
import torch

labels = torch.tensor([[1, 0, 0, 1], [1, 0, 1, 1]])
counts = torch.tensor([[2, 2], [1, 3]])
F = torch.zeros(2, 2, 4)
for i in range(2):
    for j in range(2):
        # Fill every position k where labels[i, k] == j with counts[i, j]
        F[i, j, labels[i] == j] = counts[i][j]
print(F)
```
Out:
```
tensor([[[0., 2., 2., 0.],
         [2., 0., 0., 2.]],

        [[0., 1., 0., 0.],
         [3., 0., 3., 3.]]])
```
The double for loop makes this operation slow. Is there a torch function that can make it faster?
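One possible vectorized rewrite, offered as a sketch rather than a definitive answer: the loop sets `F[i, j, k] = counts[i, j]` wherever `labels[i, k] == j` and leaves 0 everywhere else. That means the whole tensor can be built in one shot by broadcasting a comparison of `labels` against `torch.arange(C)` and multiplying the resulting mask by `counts`. The function names `loop_version` and `vectorized_version` are hypothetical. The sketch assumes `labels` has shape (N, L), `counts` has shape (N, C), and only label values 0..C-1 matter, which are exactly the values the original loop visits.

```python
import torch


def loop_version(labels: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
    """Reference: the double loop from the question, generalized to any N, C, L."""
    n, c = counts.shape
    F = torch.zeros(n, c, labels.shape[1])
    for i in range(n):
        for j in range(c):
            F[i, j, labels[i] == j] = float(counts[i, j])
    return F


def vectorized_version(labels: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
    """F[i, j, k] = counts[i, j] if labels[i, k] == j else 0, with no Python loops."""
    n, c = counts.shape
    classes = torch.arange(c, device=labels.device).view(1, c, 1)  # (1, C, 1)
    mask = labels.unsqueeze(1) == classes                          # (N, C, L), bool
    # counts.unsqueeze(-1) is (N, C, 1) and broadcasts along the L dimension
    return mask.float() * counts.unsqueeze(-1).float()


labels = torch.tensor([[1, 0, 0, 1], [1, 0, 1, 1]])
counts = torch.tensor([[2, 2], [1, 3]])

F = vectorized_version(labels, counts)
print(F)
print(torch.equal(F, loop_version(labels, counts)))  # True
```

The boolean mask has the same (N, C, L) shape as the result, so memory use is on the order of the preallocated `F` in the loop version, and all the work runs as a couple of tensor ops (on GPU as well, since `arange` is created on the same device as `labels`).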
Thanks!