Hi,

PyTorch provides the `scatter_add` operation, which can be very useful, but unfortunately it is not deterministic. A deterministic mode for this operation is AFAIK not yet implemented (see this issue).

While the built-in `scatter_add` is very general, I'm wondering if there are solutions for the special case where the index tensor is sorted. This could be a very relevant use case, where the index is just a segmentation/mask for variable-sized batches that were stacked together.

I’m thinking of something like this:

```python
>>> batch_tensor = torch.tensor([1, 2, 3, 4, 5])
>>> index_tensor = torch.tensor([0, 0, 0, 1, 1])
# either
>>> scatter_add(batch_tensor, index_tensor)
tensor([6, 9])
# or
>>> torch.zeros(2).scatter_add(batch_tensor, index_tensor)
tensor([6, 9])
```

Here we have a batch with 2 elements, `[1, 2, 3]` and `[4, 5]`, and we're doing a global sum pool.
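For reference, one deterministic way to get the same numbers with ops that already exist, assuming the index is sorted: split the batch into its chunks and sum each one. (The Python-level loop over chunks probably makes this slow when there are many small segments, though.)

```python
import torch

batch_tensor = torch.tensor([1, 2, 3, 4, 5])
index_tensor = torch.tensor([0, 0, 0, 1, 1])

# size of each consecutive run in the sorted index
_, counts = torch.unique_consecutive(index_tensor, return_counts=True)

# sum each variable-sized chunk separately; each per-chunk
# reduction is deterministic, unlike atomic-add-based scatter
sums = torch.stack(
    [chunk.sum() for chunk in torch.split(batch_tensor, counts.tolist())]
)
print(sums)  # tensor([6, 9])
```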

It seems there are already some thoughts about sorted indices out there (e.g. point 3 here), but I don't have enough insight to assess the current status. I'm wondering:

- Is there already something like this in torch?
- Is there already something like this in CUDA?
- If not, how straightforward is it to implement this?

Sure, one can implement this directly in Python by iterating over the index array, but I suspect this would be painfully slow, and there must be a better way that leverages the GPU.
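One fully vectorized, deterministic sketch I can think of (my own workaround, not a built-in): since the index is sorted, segment sums can be read off a prefix sum at the segment boundaries. `cumsum` and `searchsorted` avoid the atomic adds that make `scatter_add` nondeterministic. The helper name `sorted_segment_sum` is hypothetical, and it assumes a sorted, 0-based index whose first segment is non-empty.

```python
import torch

def sorted_segment_sum(src, index):
    """Deterministic segment sum for a sorted, 0-based index tensor.

    Sketch only: assumes segment 0 is non-empty; intermediate
    (empty) segments come out as 0.
    """
    # running sum over the flat source tensor
    csum = torch.cumsum(src, dim=0)
    num_segments = int(index[-1]) + 1
    # position of the last element of each segment
    ends = torch.searchsorted(
        index, torch.arange(num_segments, device=index.device), right=True
    ) - 1
    totals = csum[ends]
    # difference the prefix sums at the segment boundaries
    totals[1:] = totals[1:] - totals[:-1]
    return totals

batch_tensor = torch.tensor([1, 2, 3, 4, 5])
index_tensor = torch.tensor([0, 0, 0, 1, 1])
print(sorted_segment_sum(batch_tensor, index_tensor))  # tensor([6, 9])
```

Whether this is actually faster than an atomic-add kernel for typical segment sizes I don't know; it trades atomics for two extra passes over the data.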

I think a solution to this would have an immediate benefit for graph neural networks, but potentially also for other areas like CNNs with variable-sized images.

Stan.