1D `scatter_add` requires `len(indices)` < `len(values)`

Hi, I would like to implement the following in PyTorch:

```python
for i in range(indices.shape[0]):
    values[indices[i]] += 1
```
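For reference, the loop can be checked against a small test case. The sketch below wraps it in a hypothetical `count_loop` helper (assuming float `values` and int64 `indices`). It also shows why the vectorized shortcut `values[indices] += 1` is not a substitute: with advanced indexing, the old values are read once and written back, so duplicate indices are not accumulated.

```python
import torch

def count_loop(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: add 1 at each index, one step at a time."""
    out = values.clone()
    for i in range(indices.shape[0]):
        out[indices[i]] += 1
    return out

values = torch.zeros(5)
indices = torch.tensor([0, 3, 3, 0, 0, 0])

print(count_loop(values, indices))  # tensor([4., 0., 0., 2., 0.])

# Pitfall: fancy-indexed += gathers the old values, adds 1, then writes back,
# so each repeated index ends up incremented only once.
naive = values.clone()
naive[indices] += 1
print(naive)  # tensor([1., 0., 0., 1., 0.])
```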

Both `indices` and `values` are 1-D tensors. I found `scatter_add_`, which seems to fit the need perfectly:

```python
values.scatter_add_(0, indices, torch.ones_like(values))
```

However, this seems to require that `indices` be no longer than `values`; otherwise it throws an error complaining that `indices` is too long.

To reproduce, run the following code:

```python
import torch
values = torch.zeros(5)
indices = torch.LongTensor([0, 3, 3])
print(values.scatter_add(0, indices, torch.ones_like(values)))  # returns tensor([1., 0., 0., 2., 0.])

indices = torch.LongTensor([0, 3, 3, 0, 0, 0])
print(values.scatter_add(0, indices, torch.ones_like(values)))  # raises "Expected index [6] to be smaller size than src [5] and to be smaller than tensor [5] apart from dimension 0"
```

Is it designed this way, or could there be something wrong with the dimension check in this function? Also, is there any way around it?
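The error message in the repro names the actual constraint: `index` may not be longer than `src`, and in dimensions other than `dim` it may not be larger than the destination. In 1-D there is no other dimension, so the number of indices is not limited by the length of `values` (each index value must still be a valid position in `values`). A minimal sketch of the workaround, assuming float `values`: size `src` after `indices` rather than after `values`.

```python
import torch

values = torch.zeros(5)
indices = torch.tensor([0, 3, 3, 0, 0, 0])

# src needs one entry per index (here 6), not one per element of values (5).
# scatter_add requires src to have the same dtype as values, so set it explicitly.
src = torch.ones_like(indices, dtype=values.dtype)

counts = values.scatter_add(0, indices, src)  # out-of-place variant
print(counts)  # tensor([4., 0., 0., 2., 0.])
```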


You are looking for `index_add`, not `scatter_add`.

When I replace scatter_add with index_add, it complains that the number of indices should be equal to source:size(dim). In other words, the constraint is even stricter.

I guess the problem is this: it may be reasonable for `index_copy` and `scatter` to constrain the size of the indices, since one should not copy to the same location twice. But for `index_add` and `scatter_add`, such constraints do not quite make sense. My demo is the simplest possible example, and even it does not work.
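For `index_add_`, the rule is that `source.size(dim)` must equal the number of indices; the length of the tensor being modified only bounds the index values, and duplicate indices do accumulate. A small check of both cases follows. The exact exception type and message vary across PyTorch versions, so this sketch catches both `RuntimeError` and `IndexError` rather than relying on one.

```python
import torch

values = torch.zeros(5)
indices = torch.tensor([0, 3, 3, 0, 0, 0])

# Mismatch: 6 indices but a 5-element source, so index_add_ rejects it.
try:
    values.clone().index_add_(0, indices, torch.ones_like(values))
    mismatch_raised = False
except (RuntimeError, IndexError) as err:
    mismatch_raised = True
    print("rejected:", err)

# Match: one source element per index, so it is accepted and duplicates accumulate.
result = values.clone().index_add_(0, indices, torch.ones(indices.numel()))
print(result)  # tensor([4., 0., 0., 2., 0.])
```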

No… `index_add_` is exactly what you need. Why do you use `ones_like(values)`? Can you look at the docs before complaining?

```python
>>> x = torch.zeros(5)
>>> x.index_add_(0, torch.tensor([0, 0, 4, 1]), torch.ones(4))
tensor([2., 1., 0., 0., 1.])
```
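Since the goal here is only to count how often each position appears, `torch.bincount` (not mentioned in the thread) is another option for 1-D, non-negative integer indices. A sketch, assuming every index is smaller than `len(values)`; `minlength` pads the counts so their shape matches `values` (an index past the end would make the counts longer and the addition would fail).

```python
import torch

values = torch.zeros(5)
indices = torch.tensor([0, 0, 4, 1])

# bincount returns int64 counts per index value; minlength pads to len(values).
counts = torch.bincount(indices, minlength=values.numel())
values += counts.to(values.dtype)
print(values)  # tensor([2., 1., 0., 0., 1.])
```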

Thank you for your explanation. I did read the docs before posting. The error message “RuntimeError: invalid argument 4: Number of indices should be equal to source:size(dim)” gave me the wrong impression that the size of indices should match the size of “source” instead of “values”.
