PyTorch equivalent to `tf.unsorted_segment_sum`


(handesy) #1

Hi,

I was wondering what the PyTorch equivalent of tf.unsorted_segment_sum (https://www.tensorflow.org/api_docs/python/tf/unsorted_segment_sum) is. I can think of two alternatives:

  1. Use (sparse) matrix multiplication. Suppose Y has shape (M, D) and X has shape (N, D); then we can do
Y = I @ X

where I is an M-by-N sparse binary matrix with I[s, j] = 1 if row j of X belongs to segment s. The issue with this implementation is the overhead of constructing I.

  2. Use scatter_add_, like
Y = torch.zeros(M, D).scatter_add_(dim=0, index=I, src=X)

where I is an index matrix of shape (N, D). The problem with this implementation is that scatter_add_ requires M >= N, i.e. the output must have at least as many rows as the input.
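Both alternatives can be sketched in a few lines. The sketch below assumes the segment ids arrive as a 1-D tensor `segment_ids` of length N (the same input tf.unsorted_segment_sum takes), so that Y[s] is the sum of the rows X[j] with segment_ids[j] == s; the toy sizes and values are made up for illustration. It also shows `index_add_`, an option not mentioned in the thread that takes the 1-D ids directly. The M < N case used here (M = 2, N = 5) runs with scatter_add_ in recent PyTorch releases; older releases may have enforced the size check described above.

```python
import torch

# Made-up toy inputs: N = 5 rows of D = 3 features, summed into M = 2 segments (M < N).
N, D, M = 5, 3, 2
X = torch.arange(N * D, dtype=torch.float32).reshape(N, D)
segment_ids = torch.tensor([0, 1, 0, 1, 1])  # segment of each row of X

# Alternative 1: sparse (M, N) binary matrix with I[segment_ids[j], j] = 1, then Y = I @ X.
indices = torch.stack([segment_ids, torch.arange(N)])  # (2, N) tensor of (row, col) pairs
I = torch.sparse_coo_tensor(indices, torch.ones(N), size=(M, N))
Y_sparse = torch.sparse.mm(I, X)  # dense (M, D) result

# Alternative 2: scatter_add_ along dim 0, with the ids broadcast to an (N, D) index.
index = segment_ids.unsqueeze(1).expand(N, D)
Y_scatter = torch.zeros(M, D).scatter_add_(0, index, X)

# Option not in the thread: index_add_ accepts the 1-D ids without expanding them.
Y_index_add = torch.zeros(M, D).index_add_(0, segment_ids, X)

print(Y_index_add)  # rows: [6, 8, 10] (rows 0, 2 of X) and [24, 27, 30] (rows 1, 3, 4)
```

All three produce the same (M, D) tensor. index_add_ skips both building I and expanding the index to shape (N, D), so it is arguably the most direct fit; the sparse product is mainly attractive when the same I is reused for many X.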

Is there any better solution? Thanks!


#2

Would scatter_add with dim=1 work?


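import torch

index = torch.tensor([[0, 0, 1, 1, 0, 1],   # target column (segment id) of each entry
                      [1, 1, 0, 0, 1, 0]])
data = torch.tensor([[5., 1., 7., 2., 3., 4.],
                     [5., 1., 7., 2., 3., 4.]])

torch.zeros(2, 2).scatter_add(1, index, data)  # per row r: out[r][index[r][c]] += data[r][c]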
> tensor([[  9.,  13.],
          [ 13.,   9.]])
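To make the dim=1 semantics concrete, here is a plain-loop sketch that reproduces the numbers above: each row of `index` is treated as an independent segment assignment, so every row of the output is a segment sum of the matching row of `data` into 2 bins. The helper name `scatter_add_dim1_reference` is hypothetical and only meant for checking small examples, not for real use.

```python
import torch

index = torch.tensor([[0, 0, 1, 1, 0, 1],
                      [1, 1, 0, 0, 1, 0]])
data = torch.tensor([[5., 1., 7., 2., 3., 4.],
                     [5., 1., 7., 2., 3., 4.]])

def scatter_add_dim1_reference(out, index, src):
    """Hypothetical loop version of out.scatter_add(1, index, src) for 2-D tensors:
    out[r][index[r][c]] += src[r][c] for every position (r, c) of index."""
    out = out.clone()  # scatter_add (no underscore) leaves its input untouched too
    rows, cols = index.shape
    for r in range(rows):
        for c in range(cols):
            out[r, int(index[r, c])] += src[r, c]
    return out

expected = torch.zeros(2, 2).scatter_add(1, index, data)
looped = scatter_add_dim1_reference(torch.zeros(2, 2), index, data)
print(looped)  # same values as above: [[9, 13], [13, 9]]
```

For example, in row 0 the entries with id 0 are 5, 1 and 3 (sum 9) and those with id 1 are 7, 2 and 4 (sum 13).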

(handesy) #3

Thanks so much for the solution!