How to integrate outputs across steps and then compute the loss?

Hi, in my training task I need a large batch for a better gradient estimate, but the problem is memory. I know one common way to enlarge the effective batch size is gradient accumulation, but that doesn't work for me, because these two ways of computing my loss are different:

  1. Compute the loss on the output of each step and sum the losses, which is what gradient accumulation does.
  2. Integrate the outputs across several steps and then compute a single loss.
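To make the difference concrete, here is a toy sketch (the linear model and the batch-coupled loss are made up just for illustration; my real loss is more complex, but it also couples samples across the batch):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)   # toy model, stands in for my real network
data = torch.randn(8, 4)        # the "large" batch I actually want
chunks = data.split(4)          # two small batches that fit in memory

def loss_fn(out):
    # toy loss that couples samples across the batch dimension
    return out.softmax(dim=0).pow(2).sum()

# Option 1: gradient accumulation -- a loss per small batch, summed.
loss_accum = sum(loss_fn(model(c)) for c in chunks)

# Option 2: what I need -- integrate the outputs first, then one loss.
loss_joint = loss_fn(torch.cat([model(c) for c in chunks], dim=0))

# loss_accum != loss_joint here, because the softmax normalizes over
# different sets of samples, so option 1 does not compute my loss.
```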

The way I want to compute the loss is the second one. Therefore, I am wondering whether I can save the outputs of multiple steps and then concatenate them to compute the loss. Two things worry me: first, it may take quite a lot of memory to keep multiple computation graphs alive; second, even if the memory works out, the outputs used to compute the loss come from different graphs, so I suspect the result would still amount to computing the losses separately.
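Here is a sketch of what I have in mind (again with a toy model and a made-up batch-coupled loss; `split(4)` just stands in for my real small batches). I run the small batches one by one, keep every output, and `torch.cat` them before computing the loss, but I am not sure backward through the concatenated tensor really behaves like one large batch:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)   # toy model again
data = torch.randn(8, 4)

def loss_fn(out):
    # toy loss that couples samples across the batch dimension
    return out.softmax(dim=0).pow(2).sum()

# What I would like to reproduce: one large-batch forward/backward.
loss_big = loss_fn(model(data))
grad_big = torch.autograd.grad(loss_big, model.weight)[0]

# My idea: run small batches, keep every output (and its graph)...
parts = [model(c) for c in data.split(4)]
# ...then concatenate and compute the loss once over all of them.
loss = loss_fn(torch.cat(parts, dim=0))
loss.backward()
# Does model.weight.grad now match grad_big, or does backward still
# treat the two small batches separately?
```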

In a word, the correct computation of my loss requires a large batch, but due to memory I have to run several small batches. Is there any way to integrate them so that they behave as if they had run in one large batch? :grinning_face_with_smiling_eyes::thinking: