Let’s say I have three tensors a, b, c and three hyperparameters bs, L, V:
tensor a: dtype torch.long, size (bs, L), all values between 0 and V-1
tensor b: dtype torch.float, size (bs, L), all values between 0 and 1
tensor c: dtype torch.float, size (bs, V), initialized to all 0
Now I want to assign values to c based on a and b, like this:

```python
for i in range(bs):
    for j in range(L):
        c[i, a[i, j]] += b[i, j]
```
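For reference, here is that loop made runnable on a tiny hypothetical example (bs=2, L=4, V=5, with made-up values for a and b), so the effect of duplicate indices is visible: index 2 appears twice in the first row of a and index 1 three times in the second, and the += sums every contribution.

```python
import torch

# Hypothetical example sizes and values, chosen only for illustration.
bs, L, V = 2, 4, 5
a = torch.tensor([[0, 2, 2, 4],
                  [1, 1, 1, 3]])                  # torch.long, values in [0, V-1]
b = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                  [0.5, 0.25, 0.25, 1.0]])        # torch.float, values in [0, 1]
c = torch.zeros(bs, V)                            # torch.float, all zeros

# The double loop from the question: duplicate indices accumulate via +=.
for i in range(bs):
    for j in range(L):
        c[i, a[i, j]] += b[i, j]

print(c)  # row 0 ≈ [0.1, 0, 0.5, 0, 0.4], row 1 = [0, 1, 0, 1, 0]
```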
In the last line I use += because a row of a can contain duplicate values. That’s also why I am not sure whether I can use scatter_.
…
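The doubt about scatter_ is justified. A small sketch with hypothetical values contrasting it with torch.Tensor.scatter_add_: according to the PyTorch documentation, scatter_ with non-unique indices keeps one of the colliding values (which one is not guaranteed), whereas scatter_add_ sums them, which is what the += in the loop does.

```python
import torch

a = torch.tensor([[0, 2, 2, 4]])           # index 2 appears twice
b = torch.tensor([[0.1, 0.2, 0.3, 0.4]])

# scatter_: with duplicate indices only one of the colliding values survives,
# so c[0, 2] ends up as 0.2 or 0.3, never their sum.
c_scatter = torch.zeros(1, 5).scatter_(1, a, b)

# scatter_add_: colliding values are summed, matching c[i, a[i, j]] += b[i, j].
c_add = torch.zeros(1, 5).scatter_add_(1, a, b)

print(c_scatter)
print(c_add)  # approximately [[0.1, 0.0, 0.5, 0.0, 0.4]]
```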
Does anybody know whether there is an API that can do this efficiently, ideally with a single for loop or none at all?
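One possible loop-free approach, sketched under the assumptions above (a is torch.long of shape (bs, L), b is float of shape (bs, L), c starts at zero): torch.Tensor.scatter_add_ along dim 1, or advanced indexing via index_put_ with accumulate=True. The helper names accumulate_* and the sizes below are hypothetical; the sketch only checks both variants against the original double loop.

```python
import torch

def accumulate_loop(a, b, V):
    # Reference: the double loop from the question.
    c = torch.zeros(a.size(0), V, dtype=b.dtype)
    for i in range(a.size(0)):
        for j in range(a.size(1)):
            c[i, a[i, j]] += b[i, j]
    return c

def accumulate_scatter_add(a, b, V):
    # No Python loop: scatter_add_ along dim 1 adds every b[i, j] into
    # c[i, a[i, j]], so duplicate indices within a row are summed.
    c = torch.zeros(a.size(0), V, dtype=b.dtype, device=b.device)
    return c.scatter_add_(1, a, b)

def accumulate_index_put(a, b, V):
    # Alternative: index with (row, column) pairs; accumulate=True sums duplicates.
    c = torch.zeros(a.size(0), V, dtype=b.dtype, device=b.device)
    rows = torch.arange(a.size(0), device=a.device).unsqueeze(1).expand_as(a)
    return c.index_put_((rows, a), b, accumulate=True)

bs, L, V = 8, 16, 10  # hypothetical sizes
torch.manual_seed(0)
a = torch.randint(0, V, (bs, L))
b = torch.rand(bs, L)

ref = accumulate_loop(a, b, V)
print(torch.allclose(ref, accumulate_scatter_add(a, b, V)))  # True
print(torch.allclose(ref, accumulate_index_put(a, b, V)))    # True
```

scatter_add_ matches the shapes in the question directly (index and source are both (bs, L)), which is why it is the more natural fit here; index_put_ needs the explicit row indices. Results are compared with allclose because the summation order may differ from the loop.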
Thank you very much!