FSDP BackwardPrefetch documentation

The documentation for the torch.distributed.fsdp.BackwardPrefetch strategies seems to imply that the memory allocated for gradients is freed before the ReduceScatter phase (or I am just not understanding it).
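For reference, this is the setting I am asking about. A minimal sketch of how the strategy is selected when wrapping a model (assuming the process group is already initialized, e.g. via torchrun, and a CUDA device is set; the model here is just a stand-in):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, BackwardPrefetch

# Stand-in model; in practice this would be wrapped per-module into FSDP units.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

fsdp_model = FSDP(
    model,
    # Prefetch the next AllGather before the current gradient computation:
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    # Or prefetch it only after the current gradient computation:
    # backward_prefetch=BackwardPrefetch.BACKWARD_POST,
)
```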

I have drawn what I believe to be an example of the two strategies. CUDA stream 1 is responsible for the backward gradient-computation kernels, and CUDA stream 2 runs the AllGather and ReduceScatter collectives. I have also drawn what I believe to be the duration of the P (parameter) allocations and G (gradient) allocations, to illustrate when they are freed.

BACKWARD_PRE

BACKWARD_POST

The documentation claims that BACKWARD_POST saves memory because at peak it only has resident:

  • The next set of parameters
  • The current set of gradients

But since we overlap the ReduceScatter with the next gradient computation, my intuition is that at some point we must have:

  • The current parameters
  • The current gradients (mid-computation)
  • The previous gradients (mid-reduction)

This is what I have tried to show in the diagram. I would appreciate anyone letting me know what I am missing here!
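To make the concern concrete, here is a toy tally of what I believe is resident at the overlap point under BACKWARD_POST. This is my own accounting of the diagram, not anything read out of FSDP internals, and the per-layer sizes are made up:

```python
# Toy accounting of the BACKWARD_POST overlap point (my own model of the
# schedule). Sizes are per-FSDP-unit element counts, chosen arbitrarily.
full_param = 4 * 1024**2   # unsharded parameters for one FSDP unit (elements)
full_grad = full_param     # unsharded gradients have the same size

# While layer i's gradients are being reduce-scattered on stream 2, layer
# i-1's parameters and gradients are in use on stream 1 for the ongoing
# gradient computation, so all three allocations are live at once.
live = {
    "params_layer_i_minus_1": full_param,  # current parameters (mid-computation)
    "grads_layer_i_minus_1": full_grad,    # current gradients (mid-computation)
    "grads_layer_i": full_grad,            # previous gradients (mid-reduction)
}
peak_elements = sum(live.values())
print(f"peak at overlap point ≈ {peak_elements / 1024**2:.0f} Mi elements")
```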

I now think BACKWARD_POST does not actually save peak memory in a principled way. Rather, BACKWARD_POST is just prefetching incorrectly when transitioning from one module depth to another (depth when viewing the model as a tree of modules): [FSDP] incorrect backward prefetch order when using BackwardPrefetch.BACKWARD_POST · Issue #108190 · pytorch/pytorch · GitHub