# PyTorch equivalent to `tf.unsorted_segment_sum`

Hi,

I was wondering what the equivalent function to `tf.unsorted_segment_sum` (https://www.tensorflow.org/api_docs/python/tf/unsorted_segment_sum) is. I can come up with two alternatives:

1. Use (sparse) matrix multiplication. If `Y` has shape `(M, D)` and `X` has shape `(N, D)`, we can compute
``````
Y = I @ X
``````

where `I` is an `M` by `N` sparse binary matrix. The issue with this implementation is the overhead of constructing `I`.

2. Use `scatter_add_`, like
``````
Y = torch.zeros(M, D).scatter_add_(dim=0, index=I, src=X)
``````

where `I` is an index matrix of shape `(N, D)`. The problem with this implementation is that `scatter_add_` requires `M >= N`.
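To make the two alternatives above concrete, here is a minimal sketch that runs both on the same toy input and compares the results. The sizes, the values in `X`, and the names `segment_ids`, `I_sparse`, and `I_index` are assumptions made up for illustration; they are not from the original post.

```python
import torch

# Hypothetical toy sizes: N = 5 rows in X, D = 2 features, M = 3 segments.
X = torch.tensor([[1., 2.],
                  [3., 4.],
                  [5., 6.],
                  [7., 8.],
                  [9., 10.]])
segment_ids = torch.tensor([2, 0, 2, 1, 0])  # segment id of each row of X
N, D = X.shape
M = 3

# Alternative 1: Y = I @ X with a sparse (M, N) binary matrix I,
# where I[segment_ids[n], n] = 1.
coords = torch.stack([segment_ids, torch.arange(N)])
I_sparse = torch.sparse_coo_tensor(coords, torch.ones(N), size=(M, N))
Y_matmul = torch.sparse.mm(I_sparse, X)

# Alternative 2: scatter_add_ along dim 0 with an (N, D) index matrix
# that repeats each row's segment id across the D columns.
I_index = segment_ids.unsqueeze(1).expand(N, D)
Y_scatter = torch.zeros(M, D).scatter_add_(0, I_index, X)

print(Y_matmul)
print(torch.equal(Y_matmul, Y_scatter))
```

In this toy case `M < N`. As far as I can tell from the current `scatter_add_` documentation, the size constraints between `index` and `self` apply only to dimensions other than `dim`, so the call along `dim=0` should go through in recent PyTorch releases; older versions may behave differently.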

Is there any better solution? Thanks!

Would `scatter_add` with `dim=1` work?

``````
import torch

index = torch.tensor([[0, 0, 1, 1, 0, 1],
                      [1, 1, 0, 0, 1, 0]])
data = torch.tensor([[5., 1., 7., 2., 3., 4.],
                     [5., 1., 7., 2., 3., 4.]])

# Row i of the output accumulates data[i, j] into column index[i, j].
torch.zeros(2, 2).scatter_add(1, index, data)
# tensor([[ 9., 13.],
#         [13.,  9.]])
``````

Thanks so much for the solution!

Thank you, this works. But from what I understand, this should require a lot of memory, since the source data array has to be replicated x times if x different selected sums are taken? For example, the original array is replicated three times here:

``````
import torch

index = torch.tensor([[0, 0, 1, 1, 0, 1],
                      [1, 1, 0, 0, 1, 0],
                      [0, 0, 1, 1, 1, 1]])

# The same source row is written out once per selected sum.
data = torch.tensor([[5., 1., 7., 2., 3., 4.],
                     [5., 1., 7., 2., 3., 4.],
                     [5., 1., 7., 2., 3., 4.]])

torch.zeros(3, 3).scatter_add(1, index, data)
# tensor([[ 9., 13.,  0.],
#         [13.,  9.,  0.],
#         [ 6., 16.,  0.]])
``````

You could use `expand` to change the metadata without triggering a copy:

``````
import torch

index = torch.tensor([[0, 0, 1, 1, 0, 1],
                      [1, 1, 0, 0, 1, 0]])
data = torch.tensor([[5., 1., 7., 2., 3., 4.]])  # a single row, shape (1, 6)

# expand(2, -1) returns a (2, 6) view of the same row.
torch.zeros(2, 2).scatter_add(1, index, data.expand(2, -1))
``````
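As a quick way to check the no-copy claim, the sketch below inspects the expanded tensor's storage pointer and strides, then reuses it for the three-row case from the earlier post. The variable names `expanded`, `same_memory`, `strides`, and `result` are introduced here for illustration only.

```python
import torch

data = torch.tensor([[5., 1., 7., 2., 3., 4.]])
expanded = data.expand(3, -1)  # shape (3, 6), backed by the same 6 values

# Both tensors start at the same memory address.
same_memory = expanded.data_ptr() == data.data_ptr()
# A stride of 0 along dim 0 means every row reads the same memory.
strides = expanded.stride()

index = torch.tensor([[0, 0, 1, 1, 0, 1],
                      [1, 1, 0, 0, 1, 0],
                      [0, 0, 1, 1, 1, 1]])
result = torch.zeros(3, 2).scatter_add(1, index, expanded)

print(same_memory, strides)
print(result)
```

Note that the `index` tensor still has one row per selected sum, so only the source data is shared, not the indices.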

I have a similar question. My data is a 3x3 matrix with values (`src`), and the index contains column ids (`index`), indicating to which column each element of `src` belongs.

``````
import torch

# Integer literals, to match the int64 dtype.
src = torch.tensor([[1, 2, 3], [2, 0, 1], [1, 3, 4]], dtype=torch.int64)

index = torch.tensor([[0, 0, 1], [0, 0, 1], [0, 0, 1]])
``````

The result of the unsorted segment sum should be `[9, 8, 0]` (since there are 3 columns or segments). The sum over the first column (with id 0) is 9, the sum over the second column (with id 1) is 8, and the sum over the third column is 0 (since index does not contain any id 2). How can I achieve this with `scatter_add`?
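One possible way to get this, as a minimal sketch: flatten both tensors so that every element of `src` is paired with its segment id, then `scatter_add` along dimension 0 into a zero vector with one slot per segment. The name `num_segments` is an assumption introduced here for illustration, and the integer literals assume the `int64` dtype from the post.

```python
import torch

src = torch.tensor([[1, 2, 3], [2, 0, 1], [1, 3, 4]])
index = torch.tensor([[0, 0, 1], [0, 0, 1], [0, 0, 1]])
num_segments = 3  # assumed: one output slot per possible id (0, 1, 2)

# Flattening turns the problem into a 1-D segment sum:
# out[index_flat[k]] += src_flat[k] for every element k.
out = torch.zeros(num_segments, dtype=src.dtype).scatter_add(
    0, index.flatten(), src.flatten()
)
print(out)  # tensor([9, 8, 0])
```

Ids that never appear in `index` (here id 2) keep their initial value of 0, which matches the behaviour described for the third segment.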