How to implement scatter add?

Hi,

I need the function that does something like torch.Tensor.scatter_(dim, index, src).

```python
import torch

x = torch.tensor([[9., 10., 11., 12.]])
idx = torch.tensor([[4, 3, 4, 7]])  # scatter_ needs an int64 (LongTensor) index
out = torch.zeros(1, 8).scatter_(1, idx, x)
```

When I run the code above, I get the output [0, 0, 0, 10, 11, 0, 0, 12],
but I expected [0, 0, 0, 10, 20, 0, 0, 12].
Note that the value at index 4 should be 20 (9 + 11), not 11.
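
To pin down the intended semantics, here is a slow element-by-element reference: for each position j in row i, add src[i][j] into out[i][index[i][j]]. The helper name scatter_add_dim1 is hypothetical, and it only handles 2-D tensors with dim=1; it is meant as a readable specification, not a fast implementation.

```python
import torch

# Hypothetical reference helper: out[i][index[i][j]] += src[i][j]
# (2-D tensors, dim=1 only). Duplicate indices accumulate instead of overwriting.
def scatter_add_dim1(out, index, src):
    out = out.clone()  # leave the caller's tensor untouched
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            out[i, index[i, j].item()] += src[i, j]
    return out

x = torch.tensor([[9., 10., 11., 12.]])
idx = torch.tensor([[4, 3, 4, 7]])
result = scatter_add_dim1(torch.zeros(1, 8), idx, x)
print(result.tolist())  # [[0.0, 0.0, 0.0, 10.0, 20.0, 0.0, 0.0, 12.0]]
```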

Does anyone know how to implement a scatter that adds the source values to the tensor instead of overwriting them?
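
For this particular example there is a workaround that does not need a scatter-add at all: Tensor.index_add_ (a different API from scatter_ or scatter_add_) also accumulates duplicate indices. The sketch below assumes a PyTorch version with index_add_ and that every row scatters to the same positions, because index_add_ takes a single 1-D index applied along dim rather than a per-row index.

```python
import torch

x = torch.tensor([[9., 10., 11., 12.]])
idx = torch.tensor([4, 3, 4, 7])  # index_add_ takes a 1-D index along dim

out = torch.zeros(1, 8)
# For each k: out[:, idx[k]] += x[:, k], so the two writes to column 4 are summed
out.index_add_(1, idx, x)
print(out.tolist())  # [[0.0, 0.0, 0.0, 10.0, 20.0, 0.0, 0.0, 12.0]]
```

If different rows need different target positions, this shortcut does not apply and a true scatter-add is required.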

There's a function called scatter_add_ that was added to master recently; it'll be in the next release.
If you want it immediately, you can compile PyTorch from source: https://github.com/pytorch/pytorch#from-source
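
Applied to the example from the question, scatter_add_ takes the same (dim, index, src) arguments as scatter_ but sums values that land on the same position. This is a minimal sketch assuming a PyTorch build that already includes scatter_add_; the index must still be an int64 tensor.

```python
import torch

x = torch.tensor([[9., 10., 11., 12.]])
idx = torch.tensor([[4, 3, 4, 7]])  # int64 index, same shape as x

# Same call shape as scatter_, but duplicate targets are summed (9 + 11 at column 4)
out = torch.zeros(1, 8).scatter_add_(1, idx, x)
print(out.tolist())  # [[0.0, 0.0, 0.0, 10.0, 20.0, 0.0, 0.0, 12.0]]
```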