Memory-efficient way to sum over a sequence of tensors

Hi, I want to implement a simple summation over a sequence of tensors. Previously, I implemented it using stack & sum as

torch.sum(torch.stack(seq, dim=0), dim=0)

However, this implementation causes OOM for large tensors, since torch.stack materializes one big tensor of shape (len(seq), *seq[0].shape) before reducing it. A plain for-loop implementation may also OOM because gradients are needed, so the intermediate accumulations become part of the autograd graph.
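For reference, the for-loop variant I mean is roughly the following (a minimal sketch, assuming seq is a Python list of same-shaped tensors that require grad):

import torch

# Accumulate pairwise instead of materializing the stacked tensor.
# Each out-of-place add appends a node to the autograd graph, so the
# whole chain is kept alive until backward.
total = torch.zeros_like(seq[0])
for t in seq:
    total = total + t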

May I ask if there is a memory-efficient way to implement summation over a sequence? Intuitively, this feels like it should just be one bigger sum, but I cannot find a solution. I also came across the wild new nested tensor feature. Would it support summation / mean in the future? Thanks!
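For context, the nested tensor construction I am referring to looks roughly like this (prototype torch.nested API; the commented-out reduction is hypothetical and is exactly what I am asking about):

import torch

# Pack a sequence of tensors into one nested tensor (shapes may differ).
nt = torch.nested.nested_tensor([torch.randn(3, 4), torch.randn(5, 4)])

# Hoped-for reduction over the packed dimension, analogous to
# torch.sum(torch.stack(seq, dim=0), dim=0):
# nt.sum(dim=0)  # hypothetical; not claiming this is supported today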