Let’s say I have three tensors a, b, c and three hyperparameters bs, L, V:
a : dtype torch.long, size (bs, L), all values between 0 and V-1
b : dtype torch.float, size (bs, L), all values between 0 and 1
c : dtype torch.float, size (bs, V), initialized as all 0
Now I want to fill c from a and b by adding each b[i, j] into column a[i, j] of row i, like this:
```python
for i in range(bs):
    for j in range(L):
        c[i, a[i, j]] += b[i, j]
```
In the innermost statement I use `+=` because a single row of `a` can contain duplicate values. That’s also why I am not sure if I can use
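To see why duplicates matter, here is a small sketch that assumes plain advanced indexing (`c[rows, a] += b`) as the vectorized alternative; `rows` is a hypothetical helper tensor, not something from the question. Python expands this `+=` into a gather, an add, and a non-accumulating write, so when an index repeats within a row only one of the writes survives.

```python
import torch

# Tiny example: row 0 of `a` contains the index 0 twice.
a = torch.tensor([[0, 0, 2]])          # (bs=1, L=3), values in [0, V-1]
b = torch.tensor([[0.25, 0.5, 1.0]])   # (bs=1, L=3)
c = torch.zeros(1, 3)                  # (bs=1, V=3)

rows = torch.arange(a.size(0)).unsqueeze(1)  # (bs, 1), broadcasts against a

# Gather c[rows, a], add b, then write back without accumulation:
# both (0, 0) writes target the same slot, so one overwrites the other.
c[rows, a] += b
print(c)  # c[0, 0] ends up as 0.25 or 0.5, never the desired 0.75
```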
Does anybody know if there is an API that can do this efficiently, ideally with a single for loop or no for loop at all?
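One loop-free option, offered as a sketch rather than a confirmed answer, is `torch.Tensor.scatter_add_(dim, index, src)`: along dim 1 it adds every `b[i][j]` into `c[i][a[i][j]]` and sums repeated indices instead of overwriting them. `Tensor.index_put_` with `accumulate=True` is shown as an alternative. The function names below are hypothetical, and the loop version is only kept as a reference to compare against.

```python
import torch

def scatter_sum_loop(a, b, V):
    # Reference version: the double loop from the question.
    bs, L = a.shape
    c = torch.zeros(bs, V, dtype=b.dtype)
    for i in range(bs):
        for j in range(L):
            c[i, a[i, j]] += b[i, j]
    return c

def scatter_sum_vectorized(a, b, V):
    # scatter_add_ along dim 1 does c[i][a[i][j]] += b[i][j] for every (i, j),
    # summing contributions when an index repeats within a row.
    c = torch.zeros(a.size(0), V, dtype=b.dtype, device=b.device)
    c.scatter_add_(1, a, b)
    return c

def scatter_sum_index_put(a, b, V):
    # Alternative: index_put_ with accumulate=True also sums duplicates.
    c = torch.zeros(a.size(0), V, dtype=b.dtype, device=b.device)
    rows = torch.arange(a.size(0), device=a.device).unsqueeze(1).expand_as(a)
    c.index_put_((rows, a), b, accumulate=True)
    return c

bs, L, V = 4, 10, 6
a = torch.randint(0, V, (bs, L))  # L > V, so every row has duplicate indices
b = torch.rand(bs, L)

expected = scatter_sum_loop(a, b, V)
print(torch.allclose(scatter_sum_vectorized(a, b, V), expected))  # True
print(torch.allclose(scatter_sum_index_put(a, b, V), expected))   # True
```

`allclose` is used rather than exact equality because the vectorized versions may sum duplicates in a different order than the loop. On CUDA, `scatter_add_` is documented as possibly nondeterministic, so results can differ in the last bits between runs.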
Thank you very much!