PyTorch provides the
scatter_add operation, which can be very useful but is unfortunately not deterministic. As far as I know, a deterministic mode for this operation has not been implemented yet (see this issue).
While the built-in
scatter_add is very general, I'm wondering if there are solutions for the special case where the index tensor is sorted. This could be a very relevant use case where the index is just a segmentation/mask for variable-sized batches that were stacked together.
I’m thinking of something like this:
>>> batch_tensor = torch.tensor([1, 2, 3, 4, 5])
>>> index_tensor = torch.tensor([0, 0, 0, 1, 1])

# either
>>> scatter_add(batch_tensor, index_tensor)
tensor([6, 9])

# or
>>> torch.zeros(2).scatter_add(batch_tensor, index_tensor)
tensor([6, 9])
Here we have a batch with 2 elements,
[1, 2, 3] and
[4, 5], and we're doing a global sum pool.
It seems there are already some thoughts about sorted indices out there (e.g. point 3 here), but I don't have enough insight to assess the current status. I'm wondering:
- Is there already something like this in torch?
- Is there already something like this in CUDA?
- If not, how straightforward would it be to implement?
Sure, one could implement this directly in Python by iterating over the index array, but I guess that would be painfully slow, and there must be a better way that leverages the GPU.
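For what it's worth, here is a minimal vectorized sketch of the idea for sorted indices, built only from existing PyTorch ops (the function name segment_sum is my own, not a PyTorch API). It takes a cumulative sum, reads it off at the last position of each segment, and differences the result. Note this is only deterministic in the same sense that torch.cumsum is, and with floating-point inputs the summation order differs from scatter_add, so results can differ by rounding:

```python
import torch

def segment_sum(batch_tensor, index_tensor):
    # Segment sum for a *sorted* index tensor, without scatter_add.
    # Inclusive cumulative sum over the whole batch.
    csum = torch.cumsum(batch_tensor, dim=0)
    # Last position of each segment: wherever the index changes,
    # plus the final position.
    ends = torch.cat([
        (index_tensor[1:] != index_tensor[:-1]).nonzero().flatten(),
        torch.tensor([index_tensor.numel() - 1]),
    ])
    seg = csum[ends]
    # Difference consecutive segment totals to undo the cumsum.
    seg[1:] -= seg[:-1].clone()
    return seg

batch_tensor = torch.tensor([1, 2, 3, 4, 5])
index_tensor = torch.tensor([0, 0, 0, 1, 1])
print(segment_sum(batch_tensor, index_tensor))  # tensor([6, 9])
```

This stays on the GPU if the inputs do, but it is only a workaround sketch, not the fused segment-reduce kernel I'm asking about.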
I think a solution to this would have an immediate benefit for graph neural networks, but potentially also for other areas like CNNs with variable-sized images.