Understanding FSDP prefetching

How can I get a better understanding of the prefetching process during FSDP1/2 forward and backward?

I hope to learn the following:

  • The maximum amount of memory overhead prefetching will consume
  • The point at which a given parameter’s prefetch will be triggered
  • Whether FSDP1/2 exposes any hooks for when parameter / gradient fetching is initiated
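For the first bullet, my rough mental model is something like the following. This is purely a sketch of the accounting I have in mind, not an FSDP API: it assumes FSDP keeps at most `prefetch_depth + 1` parameter groups unsharded at once (the group currently computing plus the prefetched ones), and the `prefetch_depth` knob and sizes are made up for illustration.

```python
# Hedged sketch: rough upper bound on prefetch memory overhead, assuming
# FSDP holds at most `prefetch_depth + 1` param groups unsharded at once.
# `prefetch_depth` is an illustrative knob here, not an FSDP parameter.

def prefetch_overhead_bytes(group_sizes, prefetch_depth=1):
    """Worst-case extra unsharded bytes attributable to prefetching:
    the sum of the `prefetch_depth` largest groups other than the one
    currently in use."""
    largest_first = sorted(group_sizes, reverse=True)
    # The current group is resident regardless of prefetching; the
    # overhead is the next `prefetch_depth` largest groups.
    return sum(largest_first[1:1 + prefetch_depth])

# Example: three transformer blocks of 100 MB each (unsharded) plus a
# 400 MB embedding group.
sizes = [400e6, 100e6, 100e6, 100e6]
print(prefetch_overhead_bytes(sizes, prefetch_depth=1))  # 100000000.0
```

Is that roughly the right way to think about the worst case?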

I should write a more detailed post on this at some point. Let me get back to you.

It’s a bit old, but check out this post: FSDP & CUDACachingAllocator: an outsider newb perspective - distributed - PyTorch Developer Mailing List

Currently, I do not know whether FSDP2 will actively prefetch the gradients for the last layer (the first layer to be computed in the backward pass) in a training step.

I have not inspected FSDP2's source code in detail, so I can only imagine the following "what if" scenarios:

  • FSDP never reshards the gradients of the root FSDP module, but only shards inner wrapped FSDP submodules
  • FSDP starts to gather the root module’s grads when .backward() occurs, leading to an initial blocking step before any computation can occur
  • FSDP is aware of the computational graph (even in eager?) and gathers an appropriate FSDPParamGroup ahead of time, after it has tracked the first backward()
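Whichever case holds, the ordering that any pre-backward prefetch would follow can be observed empirically with plain PyTorch module hooks, no distributed setup needed. This is only a sketch: the model and hook below are illustrative, and I am assuming (not asserting) that FSDP's own pre-backward hooks fire at the same point as `register_full_backward_pre_hook`.

```python
# Hedged sketch: observe backward-pass module ordering with plain
# PyTorch hooks. A pre-backward hook fires just before a module's
# gradients are computed -- roughly the point where a BACKWARD_PRE-style
# prefetch would launch the all-gather for the next module. Plain CPU
# model, purely illustrative; no FSDP involved.
import torch
import torch.nn as nn

order = []

def make_hook(name):
    def hook(module, grad_output):
        # Fires just before this module's backward runs.
        order.append(name)
    return hook

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))
for i, layer in enumerate(model):
    layer.register_full_backward_pre_hook(make_hook(f"layer{i}"))

x = torch.randn(2, 4, requires_grad=True)
model(x).sum().backward()
print(order)  # ['layer2', 'layer1', 'layer0'] -- last layer first
```

If FSDP2 hangs its unshard logic off hooks like these, the last layer's all-gather could only start once `.backward()` reaches it, unless it is deliberately kept unsharded or issued eagerly.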

Are any of the above cases true? What really happens?