Hi,
I was wondering what the PyTorch equivalent of tf.unsorted_segment_sum
(https://www.tensorflow.org/api_docs/python/tf/unsorted_segment_sum) is. I can come up with two alternatives:
- Use (sparse) matrix multiplication. Suppose Y has shape (M, D) and X has shape (N, D); we can do Y = I @ X, where I is an M by N sparse binary matrix. The issue with this implementation is the overhead of constructing I.
- Use scatter_add_, like Y = torch.zeros(M, D).scatter_add_(dim=0, index=I, src=X), where I is an index matrix of shape (N, D). The problem with this implementation is that scatter_add_ requires M >= N.
Is there any better solution? Thanks!