I want to compute the cumulative sum within every block of k consecutive elements of a tensor, restarting the sum at the start of each block. More precisely, if the tensor is [1,2,3,4,5,6,7,8,9] and k = 3, the result should be [1,3,6,4,9,15,7,15,24]. Currently I do this with torch.cumsum and a for-loop, but I wonder whether there is a more efficient way to do it without a loop.
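For reference, a loop-based version of this might look roughly like the sketch below. It is one plausible reading of "torch.cumsum and a for-loop", not the asker's actual code, and `blockwise_cumsum_loop` is a hypothetical name. It calls torch.cumsum once per block of k elements, so the number of kernel launches grows with len(x) / k, which is what a loop-free version tries to avoid.

```python
import torch

def blockwise_cumsum_loop(x: torch.Tensor, k: int) -> torch.Tensor:
    # Hypothetical helper: run torch.cumsum separately on each block of k elements.
    out = torch.empty_like(x)
    for start in range(0, x.numel(), k):
        # Slicing past the end is fine, so a shorter last block is handled too.
        out[start:start + k] = torch.cumsum(x[start:start + k], dim=0)
    return out

result = blockwise_cumsum_loop(torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]), k=3)
print(result.tolist())  # [1, 3, 6, 4, 9, 15, 7, 15, 24]
```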
This is the best solution I have so far. Let me know if anyone comes up with something more efficient.
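import torch

k = 3
a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])  # torch.split needs a tensor, not a list
s = torch.split(a, k)                            # tuple of chunks, each of size k
cumsum = torch.stack(s).cumsum(dim=1)            # shape (len(a) // k, k): running sum within each chunk
result = torch.cat(torch.unbind(cumsum))         # back to 1-D; values 1, 3, 6, 4, 9, 15, 7, 15, 24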
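One caveat with the code above: torch.split leaves a shorter last chunk when len(a) is not a multiple of k, and torch.stack then fails because it requires equal-sized tensors. A minimal loop-free sketch, assuming a 1-D input and that zero-padding is acceptable, pads up to a multiple of k, does the block-wise cumsum, and drops the padding again. `blockwise_cumsum` is a hypothetical helper name.

```python
import torch

def blockwise_cumsum(x: torch.Tensor, k: int) -> torch.Tensor:
    # Hypothetical helper: cumulative sum restarted every k elements of a 1-D tensor.
    n = x.numel()
    pad = (-n) % k                                 # zeros needed to reach a multiple of k
    padded = torch.cat([x, x.new_zeros(pad)])      # trailing zeros only affect the discarded tail
    blocks = padded.reshape(-1, k)                 # one row per block of k elements
    return blocks.cumsum(dim=1).flatten()[:n]      # flatten and drop the padded positions

print(blockwise_cumsum(torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), k=3).tolist())
# [1, 3, 6, 4, 9, 15, 7, 15, 24, 10]
```

When len(x) is already a multiple of k, pad is 0 and the result matches the split/stack version.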
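The following is, in essence, the same as your solution, but it may be a bit
more efficient to replace your split / stack / unbind / cat processing with
reshape() and flatten(): on a contiguous tensor these return views, whereas
stack and cat each allocate and copy into a new tensor: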