I am trying to come up with an efficient way of summing uneven chunks of a tensor. The inputs are a data tensor and an index tensor of the same size, and the answer has that size as well. For example,
import torch
array = torch.Tensor([[0, 1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12, 13],
[21, 22, 23, 24, 25, 26, 27]])
index = torch.LongTensor([[0, 0, 1, 1, 1, 2, 2],
[0, 1, 1, 2, 2, 2, 2],
[0, 0, 0, 1, 1, 1, 2]])
# index can also be written as chunk sizes per row
# index = torch.LongTensor([[2, 3, 2], [1, 2, 4], [3, 3, 1]])
answer = torch.Tensor([[1, 1, 9, 9, 9, 11, 11],
[7, 17, 17, 46, 46, 46, 46],
[66, 66, 66, 75, 75, 75, 27]])
Here, every element in the first row of array is replaced by the sum of the chunk it belongs to, i.e. answer[0][2] = answer[0][3] = answer[0][4] = array[0][2:5].sum().
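To pin down the operation, here is a minimal loop-based reference, a sketch rather than a fast solution. The function name `chunk_sum_reference` is hypothetical, and the data tensor is assumed to have the three rows that match answer (the row starting at 21 is the third row).

```python
import torch

def chunk_sum_reference(array, index):
    """Slow reference: replace each element by the sum of its chunk."""
    answer = torch.empty_like(array)
    for r in range(array.size(0)):
        for chunk_id in index[r].unique():
            members = index[r] == chunk_id  # boolean mask of this chunk
            answer[r, members] = array[r, members].sum()
    return answer

# Assumed example data: three rows, matching the index tensor.
array = torch.tensor([[0., 1., 2., 3., 4., 5., 6.],
                      [7., 8., 9., 10., 11., 12., 13.],
                      [21., 22., 23., 24., 25., 26., 27.]])
index = torch.tensor([[0, 0, 1, 1, 1, 2, 2],
                      [0, 1, 1, 2, 2, 2, 2],
                      [0, 0, 0, 1, 1, 1, 2]])
result = chunk_sum_reference(array, index)
```

This runs one Python iteration per chunk, so it is only useful for checking a vectorized version against.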
I tried splitting the data tensor and summing each part, but copying each sum back into the answer is very slow unless it is parallelized, and I don't see how to parallelize it because the parts have different sizes. I'm thinking of constructing a block-diagonal ByteTensor (of size 3 x 7 x 7 in this example) from the index tensor, so that I can use batch matrix multiplication. Do you know if there's an efficient way to write this? Thank you.
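The block-diagonal idea could be sketched roughly as follows, assuming the index tensor holds per-row chunk ids as in the example. The function names are hypothetical. The first function builds the 3 x 7 x 7 mask by comparing the index tensor with itself and applies `torch.bmm`. The second is a different technique, not the one described above: `scatter_add_` followed by `gather`, which avoids the N x N mask.

```python
import torch

def chunk_sum_bmm(array, index):
    # mask[b, i, j] is 1 when positions i and j of row b share a chunk id.
    # With contiguous chunks this is one block-diagonal matrix per row.
    mask = (index.unsqueeze(2) == index.unsqueeze(1)).to(array.dtype)
    # (B, N, N) @ (B, N, 1) -> (B, N, 1): each output is its chunk's sum.
    return torch.bmm(mask, array.unsqueeze(2)).squeeze(2)

def chunk_sum_scatter(array, index):
    # Alternative: accumulate per-chunk sums, then read them back per element.
    num_chunks = int(index.max()) + 1
    sums = torch.zeros(array.size(0), num_chunks,
                       dtype=array.dtype, device=array.device)
    sums.scatter_add_(1, index, array)  # sums[b, index[b, j]] += array[b, j]
    return sums.gather(1, index)        # out[b, j] = sums[b, index[b, j]]

# Assumed example data: three rows, matching the index tensor.
array = torch.tensor([[0., 1., 2., 3., 4., 5., 6.],
                      [7., 8., 9., 10., 11., 12., 13.],
                      [21., 22., 23., 24., 25., 26., 27.]])
index = torch.tensor([[0, 0, 1, 1, 1, 2, 2],
                      [0, 1, 1, 2, 2, 2, 2],
                      [0, 0, 0, 1, 1, 1, 2]])
via_bmm = chunk_sum_bmm(array, index)
via_scatter = chunk_sum_scatter(array, index)
```

The mask version needs memory quadratic in the row length, while the scatter version needs memory proportional to the number of chunks, so the latter is likely the better fit for long rows.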