Hi, I want to implement a simple summation over a sequence of tensors. Previously, I implemented it using torch.stack followed by torch.sum:
torch.sum(torch.stack(seq, dim=0), dim=0)
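However, this implementation runs out of memory (OOM) for large tensors, because torch.stack first copies every tensor into a single new tensor of shape (N, *shape). A for-loop implementation may also cause OOM, since gradients are needed.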
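To make the two approaches concrete, here is a minimal sketch that compares the stack-and-sum version with a plain loop that keeps one running total. The helper names sum_stack and sum_loop, and the small tensor sizes, are hypothetical and chosen only for illustration. The loop avoids allocating the stacked (N, *shape) intermediate. Whether that is enough to avoid OOM in a real training setup is an assumption, since it depends on what else keeps the inputs alive for backward.

```python
import torch

def sum_stack(seq):
    # Original approach: torch.stack allocates a new (N, *shape) tensor
    # holding a copy of every element, then reduces over dim 0.
    return torch.sum(torch.stack(seq, dim=0), dim=0)

def sum_loop(seq):
    # Hypothetical alternative: keep one running total with the shape of a
    # single element, so the stacked (N, *shape) copy is never created.
    # Assumes seq is non-empty. Out-of-place `+` keeps autograd working.
    total = seq[0]
    for t in seq[1:]:
        total = total + t
    return total

# Small illustrative input: 4 tensors of shape (3, 5) that require gradients.
seq = [torch.randn(3, 5, requires_grad=True) for _ in range(4)]

stacked = torch.stack(seq, dim=0)
print(stacked.shape)  # torch.Size([4, 3, 5])

a = sum_stack(seq)
b = sum_loop(seq)
# Summation order differs, so compare with a tolerance rather than exactly.
print(torch.allclose(a, b))

# Backpropagating through the loop version gives every input a gradient of ones.
b.sum().backward()
print(all(torch.equal(t.grad, torch.ones_like(t)) for t in seq))
```

Python's built-in sum(seq) performs the same left-to-right accumulation (starting from the integer 0, which broadcasts against the first tensor), so it is a shorter way to write sum_loop.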
May I ask if there is a memory-efficient way to implement the summation over a sequence? Intuitively, this should just be one bigger sum, but I cannot find a solution. I also came across the new nested tensor feature. Will it support summation / mean in the future? Thanks!