Question About Backward–ReduceScatter Overlap in FSDP Figure 5

I have a question regarding Figure 5 in the PyTorch FSDP paper (Overlap Communication and Computation).

In the timeline, BWD1 and RS1 appear to start concurrently. My understanding is that Reduce-Scatter should only be triggered after the gradients for that FSDP unit are produced during the backward pass. This seems consistent with other units in the figure, such as BWD2 → RS2 and BWD0 → RS0, where RS clearly depends on the completion (or at least the gradient availability) of the corresponding backward computation.

Could you clarify why BWD1 and RS1 are shown as overlapping? Does RS1 begin as soon as partial gradients for that unit are ready (e.g., via gradient bucketing or chunking), or is there another scheduling mechanism involved that allows communication to overlap with the remaining backward computation?

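Your understanding is correct: Reduce-Scatter for a unit is only triggered once that unit's gradients have been produced.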

That said, BWD computation and RS can still overlap. Say you are doing the backward pass for a linear layer (y = x @ weight.T). The gradients are:

dWeight = y_grad.T @ x
dX = y_grad @ weight

FSDP will call RS when dWeight is done, but dX still has to be computed for the backward pass to continue. So in practice RS can overlap with the remaining backward computation.
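To make the independence of the two gradients concrete, here is a minimal pure-Python sketch (no PyTorch, so it runs anywhere). It checks both formulas against finite differences on a toy example. The helper names (`matmul`, `transpose`, `loss`, `numeric_grad`), the tensor values, and the scalar loss `sum(y * y_grad)` are assumptions chosen for illustration, and the comment marking where RS could start is a simplification, not a description of FSDP's actual hook placement.

```python
# Toy check of the linear-layer backward formulas for y = x @ weight.T.
# All names and values here are illustrative, not FSDP internals.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(a):
    """Return the transpose of a matrix given as a list of rows."""
    return [list(col) for col in zip(*a)]

def loss(x, weight, y_grad):
    """Scalar loss sum(y * y_grad), chosen so that dLoss/dy equals y_grad."""
    y = matmul(x, transpose(weight))
    return sum(y[i][j] * y_grad[i][j]
               for i in range(len(y)) for j in range(len(y[0])))

def numeric_grad(f, m, eps=1e-6):
    """Central finite-difference gradient of f() with respect to matrix m."""
    grad = [[0.0] * len(m[0]) for _ in m]
    for i in range(len(m)):
        for j in range(len(m[0])):
            old = m[i][j]
            m[i][j] = old + eps
            hi = f()
            m[i][j] = old - eps
            lo = f()
            m[i][j] = old  # restore the original entry
            grad[i][j] = (hi - lo) / (2 * eps)
    return grad

def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two matrices."""
    return max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))

x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]         # (batch=2, in=3)
weight = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]  # (out=2, in=3)
y_grad = [[1.0, -2.0], [0.5, 3.0]]             # (batch=2, out=2)

# Step 1: the weight gradient needs only y_grad and x.
d_weight = matmul(transpose(y_grad), x)
# Simplified view: once d_weight exists, communication on it could begin.

# Step 2: the input gradient needs only y_grad and weight, not d_weight,
# so it can be computed while that communication is in flight.
d_x = matmul(y_grad, weight)

f = lambda: loss(x, weight, y_grad)
print(max_abs_diff(d_weight, numeric_grad(f, weight)) < 1e-6)
print(max_abs_diff(d_x, numeric_grad(f, x)) < 1e-6)
```

Since `d_x` never reads `d_weight`, nothing in the math forces the reduce-scatter of the weight gradient to wait for the input gradient, which is what makes the overlap possible.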

I’m not sure if this was the intention for the diagram in the paper. I think it might just be a mistake :)

Thank you very much for your reply. It makes sense to me. There will be some overlap here.