Hi,

I was wondering what the PyTorch equivalent of `tf.unsorted_segment_sum` (https://www.tensorflow.org/api_docs/python/tf/unsorted_segment_sum) is. I can come up with two alternatives:

- Use (sparse) matrix multiplication. Suppose `Y` has shape `(M, D)` and `X` has shape `(N, D)`; we can do

```
Y = I @ X
```

where `I` is an `M` by `N` sparse binary matrix. The issue with this implementation is the overhead of constructing `I`.

- Use `scatter_add_`, like

```
Y = torch.zeros(M, D).scatter_add_(dim=0, index=I, src=X)
```

where `I` is an index matrix of shape `(N, D)`. The problem with this implementation is that `scatter_add_` requires `M >= N`.
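
For concreteness, here is a minimal sketch of both options as I understand them. The helper names `segment_sum_matmul` and `segment_sum_scatter` and the 1-D `segment_ids` vector are my own placeholders, not part of any library API, and the sketch assumes CPU tensors with int64 segment ids in the range `[0, M)`.

```python
import torch

# Hypothetical helper: segment sum via a sparse (M, N) binary matrix.
def segment_sum_matmul(X, segment_ids, M):
    N = X.shape[0]
    # I[segment_ids[n], n] = 1, so row m of I @ X sums the rows of X in segment m.
    indices = torch.stack([segment_ids, torch.arange(N)])
    values = torch.ones(N, dtype=X.dtype)
    I = torch.sparse_coo_tensor(indices, values, size=(M, N))
    return torch.sparse.mm(I, X)

# Hypothetical helper: segment sum via scatter_add_ along dim 0.
def segment_sum_scatter(X, segment_ids, M):
    # Expand the (N,) ids into an (N, D) index matrix matching X.
    index = segment_ids.unsqueeze(1).expand_as(X)
    return torch.zeros(M, X.shape[1], dtype=X.dtype).scatter_add_(0, index, X)

# Toy data: N = 4 rows, D = 3 columns, M = 4 segments (segment 3 is empty).
X = torch.arange(12, dtype=torch.float32).reshape(4, 3)
segment_ids = torch.tensor([0, 2, 0, 1])

Y_matmul = segment_sum_matmul(X, segment_ids, M=4)
Y_scatter = segment_sum_scatter(X, segment_ids, M=4)
print(Y_scatter)
```

On this toy input both versions should give the same `(M, D)` result, with rows 0 and 2 of `X` summed into segment 0 and a zero row for the empty segment.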

Is there any better solution? Thanks!