Currently, I do not know whether FSDP2 actively prefetches (all-gathers) the parameters of the last layer (the first layer to run in the backward pass) during a training step.
I have not inspected FSDP2's source code in detail, so I can imagine the following "what if"s:
FSDP never reshards the root FSDP module's parameters after forward (only those of the inner wrapped FSDP submodules), so they are already unsharded when backward begins
FSDP only starts to all-gather the root module's parameters when .backward() is called, leading to an initial blocking step before any backward computation can start
FSDP is aware of the computational graph (even in eager mode?) and all-gathers the appropriate FSDPParamGroup ahead of time, using the execution order it recorded during the first backward()
Are any of the above cases true? What really happens?
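For concreteness, here is a minimal sketch of the kind of setup I have in mind. My assumptions: a recent PyTorch where `fully_shard` is exported from `torch.distributed.fsdp` (older releases had it under `torch.distributed._composable.fsdp`), a single CUDA rank with NCCL, launched via something like `torchrun --nproc_per_node=1 repro.py` (the file name is just illustrative):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
# FSDP2 entry point; in earlier releases this was
# torch.distributed._composable.fsdp.fully_shard.
from torch.distributed.fsdp import fully_shard

# torchrun sets RANK/WORLD_SIZE/MASTER_* env vars for us.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# Toy model: four layers, each wrapped as its own FSDP2 module.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)]).cuda()
for layer in model:
    fully_shard(layer)
fully_shard(model)  # root FSDP module

x = torch.randn(8, 1024, device="cuda")
loss = model(x).sum()
# The question: before model[3] (the first layer to run in backward) can
# compute, are its parameters still unsharded from forward, all-gathered
# lazily here, or prefetched based on a recorded backward order?
loss.backward()

dist.destroy_process_group()
```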