Summing Uneven Chunks

I am trying to come up with an efficient way of summing uneven chunks of a tensor. The inputs are a data tensor and an index tensor marking which chunk each element belongs to; the answer has the same shape as the data tensor. For example,

import torch

array = torch.Tensor([[0, 1, 2, 3, 4, 5, 6],
                      [7, 8, 9, 10, 11, 12, 13],
                      [14, 15, 16, 17, 18, 19, 20],
                      [21, 22, 23, 24, 25, 26, 27]])

index = torch.LongTensor([[0, 0, 1, 1, 1, 2, 2],
                          [0, 1, 1, 2, 2, 2, 2],
                          [0, 0, 0, 1, 1, 1, 2]])

# index can also be written as per-row chunk lengths:
# index = torch.LongTensor([[2, 3, 2], [1, 2, 4], [3, 3, 1]])

answer = torch.Tensor([[1, 1, 9, 9, 9, 11, 11],
                       [7, 17, 17, 46, 46, 46, 46],
                       [66, 66, 66, 75, 75, 75, 27]])

Here, every element in the first row of array is replaced by the sum of the chunk it belongs to, e.g. answer[0][2] = answer[0][3] = answer[0][4] = array[0][2:5].sum().

I tried splitting up the data tensor and summing each part, but copying each part's sum back into the answer is very slow unless it is parallelized, and parallelizing it is hard because the parts have different sizes. I'm thinking of constructing a block-diagonal ByteTensor (of size 3 x 7 x 7 in this example) from the index tensor, so I can do batch matrix multiplication, as sketched below. Do you know if there's an efficient way to write it? Thank you.
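Concretely, I imagine something like the following rough sketch, where data and idx stand for shape-matched versions of array and index (the mask is block diagonal per row whenever the chunks are contiguous, but it costs O(N^2) memory per row):

# same_chunk[b, i, j] = 1 where positions i and j of row b belong to the same chunk
same_chunk = (idx.unsqueeze(2) == idx.unsqueeze(1)).float()  # (B, N, N)
# each output element becomes the sum of its own chunk
out = torch.bmm(same_chunk, data.unsqueeze(2)).squeeze(2)    # (B, N)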

It looks like your index is missing a row, or your array has one additional row (row 2, judging by the result).
A combination of scatter_add_ and gather would work:

import torch

array = torch.tensor([[0., 1, 2, 3, 4, 5, 6],
                      [7, 8, 9, 10, 11, 12, 13],
                      [21, 22, 23, 24, 25, 26, 27]])

index = torch.LongTensor([[0, 0, 1, 1, 1, 2, 2],
                          [0, 1, 1, 2, 2, 2, 2],
                          [0, 0, 0, 1, 1, 1, 2]])

# sum each row into its chunk bins: x[b, c] = sum of row b's chunk c
x = torch.zeros(3, 3).scatter_add_(1, index, array)
# broadcast each chunk sum back to every position of that chunk
x.gather(1, index)
> tensor([[  1.,   1.,   9.,   9.,   9.,  11.,  11.],
          [  7.,  17.,  17.,  46.,  46.,  46.,  46.],
          [ 66.,  66.,  66.,  75.,  75.,  75.,  27.]])

You’re absolutely right, and your solution works perfectly!


Is it possible to implement this in PyTorch 0.4.0? (scatter_add_ is unavailable there.)

Would it be possible to update to 0.4.1 or, ideally, to the latest version?
There are a lot of new features and performance improvements, so I would definitely recommend it.
If you are somehow forced to stay on 0.4.0, we could come up with a workaround, though it might be slower than the scatter_add_ approach.
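For instance, here's a rough sketch that should work on 0.4.0, since index_add_ has been available much longer (reusing array and index from above; num_bins, offsets, and flat_index are just names I made up):

# shift each row's chunk ids by a per-row offset so that chunks from
# different rows land in different bins of a single flat accumulator
num_rows = array.size(0)
num_bins = int(index.max()) + 1
offsets = (torch.arange(0, num_rows).long() * num_bins).unsqueeze(1)  # (3, 1)
flat_index = (index + offsets).view(-1)

# accumulate every element into its (row, chunk) bin, then broadcast back
sums = torch.zeros(num_rows * num_bins).index_add_(0, flat_index, array.view(-1))
answer = sums.view(num_rows, num_bins).gather(1, index)  # same result as above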