Let’s say I have three tensors a, b, c and three hyperparameters bs, L, V:
tensor a: dtype torch.long, size (bs, L), all values between 0 and V-1
tensor b: dtype torch.float, size (bs, L), all values between 0 and 1
tensor c: dtype torch.float, size (bs, V), initialized to all zeros
Now I want to fill c from a and b like this:
```python
for i in range(bs):
    for j in range(L):
        c[i, a[i, j]] += b[i, j]
```
In the innermost line I use += because a row of a can contain duplicate values. That’s also why I am not sure whether I can use scatter_…
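On the duplicate-index concern: plain `scatter_` assigns rather than adds, so with duplicate indices only one of the values ends up in `c`. `torch.Tensor.scatter_add_` instead sums every value that lands on the same position. A tiny check, using made-up toy values (`a_toy`, `b_toy`, `V_toy` are hypothetical, chosen so that index 2 appears twice in the row):

```python
import torch

# Toy data: row 0 of a_toy hits index 2 twice (a duplicate).
V_toy = 4
a_toy = torch.tensor([[0, 2, 2]])          # dtype torch.long, shape (1, 3)
b_toy = torch.tensor([[1.0, 2.0, 3.0]])    # dtype torch.float, shape (1, 3)

c_toy = torch.zeros(1, V_toy)
# Along dim=1 this does c_toy[i, a_toy[i, j]] += b_toy[i, j];
# the two values aimed at index 2 are summed (2.0 + 3.0).
c_toy.scatter_add_(1, a_toy, b_toy)
print(c_toy)  # tensor([[1., 0., 5., 0.]])
```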
Does anybody know if there is an API that can do this efficiently, ideally with a single for loop or no loop at all?
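A loop-free sketch, assuming the shapes and dtypes listed above: `scatter_add_` along `dim=1` performs the whole double loop in one call, and `index_put_` with `accumulate=True` is an alternative that also sums duplicates. The function names (`fill_loop`, `fill_scatter_add`, `fill_index_put`) and the small random sizes are hypothetical, used only to compare both against the original loop.

```python
import torch

def fill_loop(a, b, V):
    # Reference: the original double loop.
    bs, L = a.shape
    c = torch.zeros(bs, V, dtype=b.dtype)
    for i in range(bs):
        for j in range(L):
            c[i, a[i, j]] += b[i, j]
    return c

def fill_scatter_add(a, b, V):
    # c[i, a[i, j]] += b[i, j] for all i, j; repeated indices are summed.
    c = torch.zeros(a.shape[0], V, dtype=b.dtype, device=b.device)
    return c.scatter_add_(1, a, b)

def fill_index_put(a, b, V):
    # Pair every column index a[i, j] with its row index i, then accumulate.
    c = torch.zeros(a.shape[0], V, dtype=b.dtype, device=b.device)
    rows = torch.arange(a.shape[0], device=a.device).unsqueeze(1).expand_as(a)
    return c.index_put_((rows, a), b, accumulate=True)

# Hypothetical sizes for a quick comparison.
bs, L, V = 4, 10, 6
torch.manual_seed(0)
a = torch.randint(0, V, (bs, L))   # values in [0, V-1], dtype torch.long
b = torch.rand(bs, L)              # values in [0, 1), dtype torch.float

c_ref = fill_loop(a, b, V)
print(torch.allclose(fill_scatter_add(a, b, V), c_ref))
print(torch.allclose(fill_index_put(a, b, V), c_ref))
```

Both calls expect the index tensor `a` to be `torch.long`, which matches the setup above. The results are compared with `allclose` rather than exact equality because the vectorized versions may add duplicates in a different order than the loop.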
Thank you very much!