Specialized scatter operations for batched data


PyTorch provides the scatter_add operation, which can be very useful but unfortunately is not deterministic on the GPU (the CUDA implementation relies on atomic additions). A deterministic mode for this operation is AFAIK not yet implemented (see this issue).

While the built-in scatter_add is very general, I'm wondering if there are solutions for the special case where the index tensor is sorted. This is a very relevant use case: the index is often just a segmentation mask for variable-sized samples that were stacked together into one batch.

I’m thinking of something like this:

>>> batch_tensor = torch.tensor([1, 2, 3, 4, 5])
>>> index_tensor = torch.tensor([0, 0, 0, 1, 1])

# either
>>> scatter_add(batch_tensor, index_tensor)
tensor([6, 9])

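# or, with the existing Tensor.scatter_add signature (dim, index, src) and matching dtypes
>>> torch.zeros(2, dtype=torch.long).scatter_add(0, index_tensor, batch_tensor)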
tensor([6, 9])

Here we have a batch with two samples, [1, 2, 3] and [4, 5], and we're doing a global sum pool over each of them.
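Because the index is sorted, it carries the same information as a vector of per-sample lengths. A minimal sketch of converting between the two with standard torch ops (`torch.bincount`, `torch.repeat_interleave`), assuming all ids lie in `[0, num_segments)` and that `num_segments` is known:

```python
import torch

index_tensor = torch.tensor([0, 0, 0, 1, 1])
num_segments = 2  # assumption: known up front, e.g. the batch size

# Sorted ids -> per-sample lengths (empty segments show up as 0).
lengths = torch.bincount(index_tensor, minlength=num_segments)

# Lengths -> sorted ids, e.g. when stacking variable-sized samples into one tensor.
rebuilt = torch.repeat_interleave(torch.arange(num_segments), lengths)

print(lengths)  # tensor([3, 2])
print(rebuilt)  # tensor([0, 0, 0, 1, 1])
```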

It seems there already are thoughts about sorted indices out there (e.g. point 3 here), but I don’t have enough insight to assess the current status. I’m wondering:

  • Is there already something like this in torch?
  • Is there already something like this in CUDA?
  • If not, how straightforward is it to implement this?
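On the last question, one possible approach replaces atomic adds with a prefix sum: when the index is sorted, every segment is a contiguous slice, so its sum is the difference of two entries of a cumulative sum. This is a sketch under stated assumptions, not an existing torch API; `sorted_segment_sum` is a hypothetical name, and it assumes a sorted 1-D index with ids in `[0, num_segments)`.

```python
import torch


def sorted_segment_sum(src: torch.Tensor, index: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Hypothetical helper: sum rows of `src` per segment of a *sorted* 1-D `index`.

    Assumes `index` is sorted ascending and all ids lie in [0, num_segments).
    Works for `src` of shape (N,) or (N, F); reduction is along dim 0.
    """
    # Number of rows per segment (0 for empty segments).
    counts = torch.bincount(index, minlength=num_segments)
    # Exclusive end and inclusive start position of each segment.
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    # Prefix sum with a leading zero row, so csum[k] == src[:k].sum(dim=0).
    zero = torch.zeros_like(src[:1])
    csum = torch.cat([zero, torch.cumsum(src, dim=0)], dim=0)
    # Sum over src[start:end] == csum[end] - csum[start].
    return csum[ends] - csum[starts]


out = sorted_segment_sum(torch.tensor([1, 2, 3, 4, 5]), torch.tensor([0, 0, 0, 1, 1]), 2)
print(out)  # tensor([6, 9])
```

Caveats: for floating-point data, subtracting large prefix sums can lose precision on long inputs, and whether `torch.cumsum` itself is bitwise reproducible on CUDA may depend on the PyTorch version, so the reproducibility notes for your version are worth checking.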

Sure, one can implement this directly in Python by iterating over the index array, but I guess this would be painfully slow; there must be a better way that leverages the GPU.
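As a middle ground, the Python loop does not have to run over individual elements. A hedged sketch (hypothetical helper `segment_sum_split`, assuming a sorted index) that loops once per segment using `torch.split`, so each sum is an ordinary reduction rather than an atomic scatter:

```python
import torch


def segment_sum_split(src: torch.Tensor, index: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Hypothetical helper: one Python iteration per segment, not per element.

    Assumes `index` is sorted ascending with ids in [0, num_segments).
    """
    counts = torch.bincount(index, minlength=num_segments)
    # torch.split needs a Python list of sizes; on GPU this forces a host sync.
    chunks = torch.split(src, counts.tolist(), dim=0)
    # Each segment is reduced independently with a plain sum.
    return torch.stack([chunk.sum(dim=0) for chunk in chunks])
```

The cost of this version grows with the number of segments (one kernel launch per sum), so it is mainly useful when batches contain relatively few samples.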

I think a solution to this would have an immediate benefit for graph neural networks, but potentially also for other areas like CNNs with variable-sized images.

