I’m encountering a confusing behavior when using advanced indexing to update values in a tensor. Here’s an example:
import torch
A = torch.zeros(2, 2)
Inds = torch.tensor([[0,0,1],[1,1,1]])
A[Inds.unbind()] += torch.tensor([1, 0, 1])
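For concreteness, printing A right after this line shows the result I describe next (the printed tensor is just a transcription of the behavior described below):

print(A)
# tensor([[0., 0.],
#         [0., 1.]])
# i.e. A[0, 1] stayed at 0 while A[1, 1] became 1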
I expected A[0, 1] to increment by 1, but its value remains 0, while the value of A[1, 1] is incremented by 1. At first I thought this was related to the in-place operation, so I tried an explicit addition as shown below:
A[Inds.unbind()] = A[Inds.unbind()] + torch.tensor([1, 0, 1])
After this operation, A[0, 1] is still 0, whereas A[1, 1] equals 2 (which is the correct value, since it was equal to 1 before).
My current fix is a loop over the values of Inds (sketched below), but I would like to understand why this happens and whether there is a way to vectorize the update instead.
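For reference, my loop-based workaround looks roughly like this (a sketch of the idea, using the same A, Inds, and update values as above):

import torch

A = torch.zeros(2, 2)
Inds = torch.tensor([[0, 0, 1], [1, 1, 1]])
vals = torch.tensor([1., 0., 1.])

# update one (row, col) pair at a time so that repeated
# indices accumulate instead of overwriting each other
for (r, c), v in zip(Inds.T.tolist(), vals):
    A[r, c] += v

# A[0, 1] is now 1 and A[1, 1] is now 1, as I would expect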