Split and sum with a more efficient method

Is there a more efficient way to accomplish the following (for general X and sizes)?

import torch

sizes = [3, 7, 5, 9]
X = torch.ones(sum(sizes))
# Sum each chunk of X, where the chunk lengths are given by sizes
Y = torch.tensor([s.sum() for s in torch.split(X, sizes)])
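(For reference, with the all-ones X above this produces Y = tensor([3., 7., 5., 9.]), i.e. one sum per chunk.)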


For instance, with NumPy I can do the following:

import numpy as np

# Start index of each chunk: [0, 3, 10, 15]
indices = np.cumsum([0] + sizes)[:-1]
# Note: Y is a NumPy array here, not a torch tensor
Y = np.add.reduceat(X.numpy(), indices)

Is there a PyTorch equivalent for this?

I'm not sure what the performance difference would be, but you could create an index tensor and use scatter_add_:

# Map each element of X to the index of the chunk it belongs to
idx = torch.cat([torch.tensor([i] * s) for i, s in enumerate(sizes)])
Y = torch.zeros(len(sizes)).scatter_add_(0, idx, X)
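
If the Python loop over sizes is a concern, the index tensor could presumably also be built in a single call with torch.repeat_interleave. This is just a sketch of the same scatter_add_ idea, not benchmarked:

# Build the chunk-index tensor without a Python loop:
# [0, 0, 0, 1, 1, ..., 3]
idx = torch.repeat_interleave(torch.arange(len(sizes)), torch.tensor(sizes))
Y = torch.zeros(len(sizes)).scatter_add_(0, idx, X)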