Hi,
I need a function that does something like torch.Tensor.scatter_(dim, index, src).
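import torch
x = torch.tensor([[9., 10., 11., 12.]])  # float, same dtype as the zeros tensor
idx = torch.tensor([[4, 3, 4, 7]])       # scatter_ needs an int64 (long) index
out = torch.zeros(1, 8).scatter_(1, idx, x)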
When I run the code above, I get [0, 0, 0, 10, 11, 0, 0, 12], whereas I expect [0, 0, 0, 10, 20, 0, 0, 12].
Note that the value at index 4 should be 20 (9 + 11), not 11.
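For reference, here is a plain-Python sketch of the two behaviours on the 1×8 example, flattened to a single row. The function names are hypothetical. The overwrite version assumes writes happen left to right; PyTorch does not guarantee which value wins when indices repeat, so that model only happens to match the output I got.

```python
def scatter_overwrite(out, index, src):
    """Model of plain scatter_ on one row: each src value replaces out[index].
    Assumes writes are applied left to right (not guaranteed by PyTorch)."""
    result = list(out)  # copy so the input list is left unchanged
    for i, v in zip(index, src):
        result[i] = v   # a later write to the same slot replaces the earlier one
    return result


def scatter_accumulate(out, index, src):
    """Desired behaviour: values that share an index are summed into out."""
    result = list(out)
    for i, v in zip(index, src):
        result[i] += v  # add instead of overwrite
    return result


x = [9, 10, 11, 12]
idx = [4, 3, 4, 7]
zeros = [0] * 8
print(scatter_overwrite(zeros, idx, x))   # [0, 0, 0, 10, 11, 0, 0, 12]
print(scatter_accumulate(zeros, idx, x))  # [0, 0, 0, 10, 20, 0, 0, 12]
```

Index 4 receives both 9 and 11: the overwrite model keeps only the last one (11), while the accumulate model sums them (20).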
Does anyone know how to implement a scatter that adds the source values to the tensor instead of overwriting them?