Can you store a cumulative sum of weight updates in a memory-efficient way without applying them?

Hi, I’m doing a meta-learning AI project (using MAML), which requires making multiple copies of one model, computing the loss, adapting each copy (without calling backward()), and then computing the loss again for each task.

The third step requires aggregating all the losses (and, by default, their compute graphs) into a single loss for the meta-learner to optimize. This dramatically increases memory requirements, to the point that the program crashes almost immediately. A sketch of what I mean is below.
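For concreteness, here is a minimal sketch of my current setup, not my actual code. All the names (model, tasks, inner_lr, the torch.func-based adaptation, the dummy data) are placeholders, and I'm assuming a recent PyTorch with torch.func available:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

model = nn.Linear(10, 1)                      # stand-in for my real model
meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
inner_lr = 0.01

def make_task():
    # dummy support/query data just so the sketch runs
    return (torch.randn(8, 10), torch.randn(8, 1),
            torch.randn(8, 10), torch.randn(8, 1))

tasks = [make_task() for _ in range(16)]

params = dict(model.named_parameters())
meta_loss = 0.0
for xs, ys, xq, yq in tasks:
    # inner-loop adaptation without calling backward()
    support_loss = loss_fn(functional_call(model, params, (xs,)), ys)
    grads = torch.autograd.grad(support_loss, tuple(params.values()),
                                create_graph=True)
    adapted = {name: p - inner_lr * g
               for (name, p), g in zip(params.items(), grads)}
    # query loss on the adapted parameters, summed across tasks
    meta_loss = meta_loss + loss_fn(functional_call(model, adapted, (xq,)), yq)

meta_opt.zero_grad()
meta_loss.backward()   # every task's graph is still alive here -> memory blows up
meta_opt.step()
```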

What I want to know is whether there is a memory-efficient way to keep a cumulative sum of weight updates across meta-learning tasks, so that I don’t end up with a massive compute graph from just doing loss1 + loss2 + … etc. In particular, would calling backward() multiple times achieve something like this?
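This (continuing from the sketch above, same placeholder names) is roughly what I have in mind by "calling backward() multiple times": backpropagating each task's query loss immediately so its graph can be freed before the next task, and letting the gradients just accumulate in .grad. I'm not sure whether this ends up equivalent to summing the losses first, which is really what I'm asking:

```python
meta_opt.zero_grad()
params = dict(model.named_parameters())
for xs, ys, xq, yq in tasks:
    # same inner-loop adaptation as above
    support_loss = loss_fn(functional_call(model, params, (xs,)), ys)
    grads = torch.autograd.grad(support_loss, tuple(params.values()),
                                create_graph=True)
    adapted = {name: p - inner_lr * g
               for (name, p), g in zip(params.items(), grads)}
    query_loss = loss_fn(functional_call(model, adapted, (xq,)), yq)
    # backward() per task: dividing by len(tasks) so the accumulated .grad
    # would match the gradient of the mean meta-loss; this task's graph is
    # released as soon as backward() returns
    (query_loss / len(tasks)).backward()
meta_opt.step()
```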