The documentation for the torch.distributed.fsdp.BackwardPrefetch strategy seems to imply that the memory allocated for gradients is freed before the ReduceScatter phase (or I am just not understanding it).
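For context, this is the option I mean. The snippet below is only a minimal sketch to make the question concrete: the model, the single-rank process group, and the sizes are placeholders, not my actual training code, and it assumes one visible GPU.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-rank NCCL process group just so the sketch is self-contained;
# the real job is launched with torchrun across several GPUs.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)

# Placeholder model; in practice it is split into several FSDP units so that
# prefetching the next unit's parameters actually matters.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# The two strategies in question:
#   BACKWARD_PRE  - all-gather the next unit's parameters before the current
#                   unit's gradient computation
#   BACKWARD_POST - all-gather them after the current unit's gradient computation
fsdp_model = FSDP(
    model,
    backward_prefetch=BackwardPrefetch.BACKWARD_POST,
    device_id=torch.device("cuda", 0),
)

dist.destroy_process_group()
```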
I have drawn what I believe to be an example of the two strategies. CUDA stream 1 is responsible for the backward gradient-computation kernels, and CUDA stream 2 runs the AllGather and ReduceScatter collectives. I have also drawn what I believe to be the lifetimes of the P (parameter) and G (gradient) allocations, to illustrate when they are freed.
BACKWARD_PRE
BACKWARD_POST
The documentation claims that BACKWARD_POST saves memory because at peak only the following are resident:
- The next set of parameters
- The current set of gradients
But since we are overlapping the ReduceScatter with the next gradient computation, my intuition is that at some point we must simultaneously have:
- The current parameters
- The current gradients (mid-computation)
- The previous gradients (mid-reduction)
This is what the diagram is meant to show. I would appreciate anyone letting me know what I am missing here!
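For concreteness, here is the back-of-the-envelope accounting behind my intuition, written as a toy script. The per-unit size is a made-up number, and I am assuming each FSDP unit's unsharded parameters and gradients occupy the same amount of memory.

```python
# Toy accounting of what I think is resident at the worst point of the
# backward pass under BACKWARD_POST. `unit_bytes` is an arbitrary figure;
# assume the params and grads of one FSDP unit are the same size.
unit_bytes = 100

# What I read the documentation as claiming is resident at peak:
doc_claim = {
    "next parameters (all-gathered)": unit_bytes,
    "current gradients": unit_bytes,
}

# What my diagram suggests must coexist while the ReduceScatter of the
# previous unit's gradients overlaps the current unit's gradient kernels:
my_reading = {
    "current parameters (all-gathered)": unit_bytes,
    "current gradients (mid-computation)": unit_bytes,
    "previous gradients (mid-reduction)": unit_bytes,
}

for label, breakdown in (("documentation", doc_claim), ("my diagram", my_reading)):
    print(f"{label}: {sum(breakdown.values())} units -> {', '.join(breakdown)}")
```

Under my reading the peak is three unit-sized allocations rather than two, which is why I suspect I am misunderstanding either the docs or the point at which the gradient buffers are actually freed.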