Sum along axis with bins

I am trying to sum a tensor along a single dimension while grouping outputs by a separate label tensor.
The functionality I’m looking for is similar to torch.bincount with weights, if it could be applied along a single dimension of a multi-dimensional tensor. Is there a torch function that has this functionality?
E.g.

>>> a
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
>>> b
tensor([0, 0, 1, 2])
# And I want to produce the following.
tensor([[ 5,  7,  9],
        [ 7,  8,  9],
        [10, 11, 12]])
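As a point of reference, torch.bincount with its weights argument already does this grouping for a 1-D tensor. The sketch below simply applies it once per column and stacks the results. It is only meant to pin down the expected semantics (a Python loop over columns, not an efficient solution), and note that with integer weights bincount returns a floating-point result.

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [10, 11, 12]])
b = torch.tensor([0, 0, 1, 2])  # bin label for each row of a

# torch.bincount(b, weights=w) adds w[i] into bin b[i], but w must be 1-D,
# so run it on each column of a separately and stack the columns back.
cols = [torch.bincount(b, weights=a[:, j]) for j in range(a.size(1))]
c = torch.stack(cols, dim=1)  # shape (num_bins, 3)
print(c)
```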

Thanks for any help!

You can use Tensor.index_add (or, here, its in-place version Tensor.index_add_) for this. It adds each row a[i] into row b[i] of the output:

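import torch

a = torch.arange(1, 13.).view(4, 3)  # values to sum, shape (4, 3)
b = torch.tensor([0, 0, 1, 2])       # bin label for each row of a

# One output row per bin; c must have the same dtype as a
c = torch.zeros(int(b.max()) + 1, a.size(1), dtype=a.dtype)
c.index_add_(0, b, a)  # c[b[i]] += a[i] for every row i
print(c)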
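The same call works along any dimension, since index_add takes the dimension as its first argument. The sketch below wraps the out-of-place Tensor.index_add in a small helper; `binned_sum` and its `num_bins` parameter are hypothetical names for illustration, not part of torch, and it assumes `labels` is a non-empty 1-D integer tensor whose length equals `x.size(dim)`.

```python
import torch

def binned_sum(x, labels, dim=0, num_bins=None):
    """Hypothetical helper (not a torch API): sum the slices of x along
    `dim` into bins, where labels[i] is the bin of slice i."""
    if num_bins is None:
        num_bins = int(labels.max()) + 1  # assumes labels is non-empty
    shape = list(x.shape)
    shape[dim] = num_bins
    out = x.new_zeros(shape)  # same dtype and device as x
    # Out-of-place index_add: returns a new tensor, `out` is left unchanged
    return out.index_add(dim, labels, x)

a = torch.arange(1, 13.).view(4, 3)
rows = binned_sum(a, torch.tensor([0, 0, 1, 2]))          # group rows, shape (3, 3)
cols = binned_sum(a, torch.tensor([1, 0, 1]), dim=1)      # group columns, shape (4, 2)
print(rows)
print(cols)
```

Grouping columns with labels [1, 0, 1] puts column 1 in bin 0 and sums columns 0 and 2 into bin 1.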

Best regards

Thomas
