Accumulating grads with DDP with one forward pass and multiple backwards

In my training I need to perform a single forward pass, followed by multiple backward passes with `retain_graph=True`, and finally one last call to backward before `optimizer.step()`.

I do this because I compute a heavy loss that causes a CUDA OOM before reaching the final backward if I accumulate the 'partial' losses and call backward on everything together.
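To make the pattern concrete, here is a minimal single-GPU sketch of it, with a toy `Linear` model and made-up partial losses standing in for the real ones (all names here are illustrative, not from the original post):

```python
import torch

# Toy stand-ins for the real model, optimizer, and losses.
model = torch.nn.Linear(4, 3)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)
out = model(x)  # single forward pass

# Several "partial" losses backpropagated one at a time; retain_graph=True
# keeps the graph alive so later losses can still backprop through it.
# Each backward call accumulates into param.grad, so the intermediate
# loss tensors can be freed instead of all being held until the end.
for i in range(2):
    partial_loss = out[:, i].pow(2).mean()
    partial_loss.backward(retain_graph=True)

# Final backward frees the graph; grads from all backwards have accumulated.
final_loss = out.pow(2).mean()
final_loss.backward()
opt.step()
```

Because gradients accumulate across backward calls, this produces the same `param.grad` as summing all the losses first and calling backward once - but without keeping every loss's backward state alive simultaneously.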

I now want to train on multiple GPUs, but from what I see here it seems that I can't use `no_sync` for my use case, as the forward pass also needs to be inside the context.
Another point is that the number of backward passes before the final one won't be the same across the different inputs.

Is there a way for me to disable syncing until the final backward pass?

@bary DistributedDataParallel — PyTorch 2.0 documentation

Maybe you could try the `static_graph=True` option, as that supports re-entrant backward, but I am not sure whether that would help disable the gradient sync until the final backward pass.

cc @Yanli_Zhao

What I ended up doing as a workaround is to add a dummy forward pass before the final backward pass - this made the gradients differ across ranks before the final backward, and be equal after it.