PyTorch equivalent to `tf.unsorted_segment_sum`

Hi,

I was wondering what the PyTorch equivalent of tf.unsorted_segment_sum (https://www.tensorflow.org/api_docs/python/tf/unsorted_segment_sum) is. I can come up with two alternatives:

  1. Use (sparse) matrix multiplication. Suppose Y has shape (M, D) and X has shape (N, D); then we can do
Y = I @ X

where I is an M-by-N sparse binary matrix. The drawback of this implementation is the overhead of constructing I.

  2. Use scatter_add_, like
Y = torch.zeros(M, D).scatter_add_(dim=0, index=I, src=X)

where I is an index matrix of shape (N, D). The problem with this implementation is that scatter_add_ requires M >= N.
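Both options can be sketched side by side on toy data. This is a minimal sketch, not a reference implementation: the helper names `segment_sum_sparse` and `segment_sum_scatter` are hypothetical, and it assumes the segment ids are a 1-D int64 tensor of length N. For option 1, `selector` plays the role of I; for option 2, the (N, D) index is built from the ids with `expand_as`, which creates a view rather than a copy. With a recent PyTorch release the scatter version runs with M = 3 < N = 4, since the documented constraint `index.size(d) <= self.size(d)` only applies to dimensions `d != dim`; older releases may have behaved differently.

```python
import torch

def segment_sum_sparse(x, segment_ids, num_segments):
    # Option 1 (hypothetical helper): Y = I @ X, where I is a sparse (M, N)
    # 0/1 matrix with I[segment_ids[j], j] = 1 for every input row j.
    n = x.shape[0]
    indices = torch.stack([segment_ids, torch.arange(n)])
    values = torch.ones(n, dtype=x.dtype)
    selector = torch.sparse_coo_tensor(indices, values, size=(num_segments, n))
    return torch.sparse.mm(selector, x)  # sparse @ dense -> dense (M, D)

def segment_sum_scatter(x, segment_ids, num_segments):
    # Option 2 (hypothetical helper): scatter_add_ along dim 0. The index must
    # have the same shape as src, so broadcast the (N,) ids to (N, D);
    # expand_as() returns a view, so no copy of the ids is made.
    index = segment_ids.unsqueeze(1).expand_as(x)
    out = torch.zeros(num_segments, x.shape[1], dtype=x.dtype)
    return out.scatter_add_(0, index, x)  # out[index[j, k], k] += x[j, k]

x = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])  # N = 4, D = 2
segment_ids = torch.tensor([1, 0, 1, 1])                    # unsorted ids
y_sparse = segment_sum_sparse(x, segment_ids, num_segments=3)   # M = 3 < N
y_scatter = segment_sum_scatter(x, segment_ids, num_segments=3)
print(y_sparse)
print(torch.equal(y_sparse, y_scatter))
```

Both calls should give `[[3, 4], [13, 16], [0, 0]]`: segment 0 receives row 1, segment 1 receives rows 0, 2 and 3, and segment 2 receives nothing.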

Is there any better solution? Thanks!
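A related built-in worth knowing here is `Tensor.index_add_(dim, index, source)`, which takes a 1-D vector of segment ids directly, so neither the sparse matrix nor an (N, D) index has to be built. A minimal sketch on the same kind of toy data (the data and `num_segments` are illustrative assumptions):

```python
import torch

x = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])  # (N, D) = (4, 2)
segment_ids = torch.tensor([1, 0, 1, 1])                    # one id per row, unsorted
num_segments = 3  # segment 2 receives no rows and stays zero

# y[segment_ids[j]] += x[j] for every row j
y = torch.zeros(num_segments, x.shape[1], dtype=x.dtype).index_add_(0, segment_ids, x)
print(y)
```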

Would scatter_add with dim=1 work?


import torch

index = torch.tensor([[0, 0, 1, 1, 0, 1],
                      [1, 1, 0, 0, 1, 0]])
data = torch.tensor([[5., 1., 7., 2., 3., 4.],
                     [5., 1., 7., 2., 3., 4.]])

torch.zeros(2, 2).scatter_add(1, index, data)
> tensor([[  9.,  13.],
          [ 13.,   9.]])
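To see what the `dim=1` call computes, it can be compared against a plain Python loop. With `dim=1`, each row is reduced independently: `out[i][index[i][j]] += data[i][j]`. The loop below is only a slow reference for checking, not something to use in practice:

```python
import torch

index = torch.tensor([[0, 0, 1, 1, 0, 1],
                      [1, 1, 0, 0, 1, 0]])
data = torch.tensor([[5., 1., 7., 2., 3., 4.],
                     [5., 1., 7., 2., 3., 4.]])

fast = torch.zeros(2, 2).scatter_add(1, index, data)

# Reference loop: row i of data is summed into row i of the output,
# at the column given by the matching entry of index
slow = torch.zeros(2, 2)
for i in range(index.shape[0]):
    for j in range(index.shape[1]):
        slow[i, index[i, j].item()] += data[i, j]

print(torch.equal(fast, slow))  # True
```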

Thanks so much for the solution!


Thank you, this works. But from what I understand, this would require a lot of memory, since the source data array has to be replicated x times if x different selected sums are taken. For example, the original array is replicated three times here:

import torch

index = torch.tensor([[0, 0, 1, 1, 0, 1],
                      [1, 1, 0, 0, 1, 0],
                      [0, 0, 1, 1, 1, 1]])

data = torch.tensor([[5., 1., 7., 2., 3., 4.],
                     [5., 1., 7., 2., 3., 4.],
                     [5., 1., 7., 2., 3., 4.]])

torch.zeros(3, 3).scatter_add(1, index, data)
> tensor([[ 9., 13.,  0.],
          [13.,  9.,  0.],
          [ 6., 16.,  0.]])

You could use expand to change the metadata without triggering a copy:

import torch

index = torch.tensor([[0, 0, 1, 1, 0, 1],
                      [1, 1, 0, 0, 1, 0]])
data = torch.tensor([[5., 1., 7., 2., 3., 4.]])

torch.zeros(2, 2).scatter_add(1, index, data.expand(2, -1))
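One way to confirm that `expand` does not copy is to inspect the result: it shares the original storage (same `data_ptr()`) and has stride 0 along the expanded dimension, so every "row" reads the same memory. A small check, using the data from the post:

```python
import torch

index = torch.tensor([[0, 0, 1, 1, 0, 1],
                      [1, 1, 0, 0, 1, 0]])
data = torch.tensor([[5., 1., 7., 2., 3., 4.]])

expanded = data.expand(2, -1)
# A view: same storage as data, stride 0 along the new rows
print(expanded.shape, expanded.stride())         # torch.Size([2, 6]) (0, 1)
print(expanded.data_ptr() == data.data_ptr())    # True

out = torch.zeros(2, 2).scatter_add(1, index, expanded)
print(out)  # same result as with the explicitly repeated data
```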

I have a similar question. My data is a 3x3 matrix of values (src), and index holds a segment id for each element of src, indicating which segment (output column) that element belongs to.

import torch

src = torch.tensor([[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]], dtype=torch.int64)

index = torch.tensor([[0, 0, 1], [0, 0, 1], [0, 0, 1]])

The result of the unsorted segment sum should be [9, 8, 0], since there are 3 segments. The sum of all elements with id 0 (columns 0 and 1 of src) is 9, the sum of all elements with id 1 (column 2) is 8, and the sum for id 2 is 0, because index does not contain the id 2. How can I achieve this with scatter_add?
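Because every element is summed into its segment regardless of which row it sits in, one possible approach (a sketch, assuming the segment count is known up front as `num_segments`) is to flatten both tensors to 1-D and scatter along dim 0 into a zero vector with one slot per segment:

```python
import torch

src = torch.tensor([[1, 2, 3], [2, 0, 1], [1, 3, 4]], dtype=torch.int64)
index = torch.tensor([[0, 0, 1], [0, 0, 1], [0, 0, 1]])
num_segments = 3  # assumed known; segments with no elements stay 0

# out[index.flatten()[k]] += src.flatten()[k] for every element k
out = torch.zeros(num_segments, dtype=src.dtype).scatter_add_(0, index.flatten(), src.flatten())
print(out)  # tensor([9, 8, 0])
```

The output buffer must have the same dtype as `src`, which is why it is created with `dtype=src.dtype`.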