I’m struggling to understand how a local batch of data is processed by a single GPU in Fully Sharded Data Parallel (FSDP) training. As far as I know, a GPU only holds a shard of the model params, it processes its local batch of data, and it belongs to a certain FSDP unit. In general, there could be multiple FSDP units (see this diagram), each holding a different shard of the model. During the forward pass, the all_gather will collect all the params of the given FSDP unit into each GPU belonging to it. After the all_gather, those GPUs still have only a fraction of the whole model, as the all_gather operates inside a single FSDP unit alone. Consequently, local batches of data on those GPUs will be processed by that fraction of the model. When will those batches be processed by the full model params?
After the all_gather, those GPUs still have only a fraction of the whole model
The GPUs have a fraction of the whole model, but it is along a different axis. You can think of it like pipelining. For example, if we had a sequence of FSDP units (e.g. transformer blocks), FSDP would all-gather one FSDP unit (one transformer block), run forward with the local batch, free that unit’s unsharded parameters, all-gather the next FSDP unit, run forward with the local batch, etc.
FSDP units are chosen on the unsharded model, partitioning it into units/groups without sharding the parameters yet. After they are chosen, the parameters within each unit are fully sharded across GPUs. The all-gather operation undoes the sharding of parameters, but not the partitioning of the model into groups. To run forward/backward, only the active unit needs to be all-gathered.
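If it helps, here is a toy sketch of that per-unit loop in plain Python. This is only a conceptual simulation (the `shard`/`all_gather` helpers and all sizes are made up), not the real PyTorch FSDP code:

```python
# Toy sketch of FSDP's per-unit forward loop (conceptual only, not the
# real PyTorch FSDP implementation). Model = a list of FSDP units; each
# unit's parameters are split into one shard per GPU ("rank").

WORLD_SIZE = 2

# Hypothetical full parameters for 3 FSDP units (e.g. transformer blocks).
full_units = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

def shard(unit, rank):
    # Rank r keeps slice r of this unit's parameters.
    n = len(unit) // WORLD_SIZE
    return unit[rank * n:(rank + 1) * n]

# Persistent state: rank r holds shard r of EVERY unit.
local_shards = {
    rank: [shard(u, rank) for u in full_units]
    for rank in range(WORLD_SIZE)
}

def all_gather(unit_idx):
    # Simulates the collective: concatenate every rank's shard of one unit.
    return sum((local_shards[r][unit_idx] for r in range(WORLD_SIZE)), [])

def forward(rank, local_batch):
    x = local_batch
    for i in range(len(full_units)):
        unit_params = all_gather(i)  # materialize the FULL unit on this rank
        x = x * sum(unit_params)     # stand-in for the unit's real forward
        del unit_params              # free the unsharded params immediately
    # Note: the compute never depends on `rank` -- every rank sees the same
    # gathered params; only the local batch differs (data parallelism).
    return x

print(forward(rank=0, local_batch=1.0))  # GPU0's batch sees every full unit
print(forward(rank=1, local_batch=2.0))  # GPU1's batch, same full units
```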
Thanks a lot for the reply. I’m still a bit confused.
For example, if we had a sequence of FSDP units (e.g. transformer blocks), FSDP would all-gather one FSDP unit (one transformer block), run forward with the local batch, free that unit’s unsharded parameters, all-gather the next FSDP unit, run forward with the local batch, etc.
Ok, once all FSDP units are done, we have a complete forward pass of the model on a local batch. Right? Then is the process repeated to account for every other local batch on other GPUs?
Yes, since FSDP (fully sharded data parallel) is a form of data parallelism, each worker is processing its own local batch.
E.g., if you had 2 GPUs, then GPU0 and GPU1 all-gather the 0th FSDP unit, each computes forward on its respective local batch, and then they proceed to the 1st FSDP unit, etc.
Then is the process repeated to account for every other local batch on other GPUs?
In other words, this is happening in parallel on all GPUs.
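A purely illustrative trace of that lock-step (made-up unit count and world size):

```python
# Illustrative trace: every rank walks the SAME sequence of FSDP units,
# each feeding its OWN local batch through them.
NUM_UNITS, WORLD_SIZE = 3, 2

for unit in range(NUM_UNITS):
    for rank in range(WORLD_SIZE):
        print(f"rank {rank}: all-gather unit {unit}, "
              f"forward(local_batch[{rank}]), free unit {unit}")
```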
GPUs 0 and 1 do an all-gather of the 1st FSDP unit even though they belong to the 0th unit? And then they all-gather the 2nd, and so on?
GPUs 0 and 1 do an all-gather of the 1st FSDP unit even though they belong to the 0th unit?
I am not sure I follow this. Every GPU has a shard of every FSDP unit.
Said another way (if the FSDP unit terminology is confusing): you can think of each FSDP unit as a module/layer. Suppose your model is a sequence of layers. You can partition your model layer by layer. For each layer, all GPUs all-gather, run forward, and free. Each GPU always has a shard of each layer; only when they all-gather for computation do all GPUs have the entire layer.
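A minimal sketch of that layout, with made-up names and a hypothetical 4-GPU / 3-layer setup (again, not the real FSDP implementation):

```python
# Every rank holds a shard of EVERY layer; no rank "owns" a layer.
WORLD_SIZE = 4
LAYERS = ["layer0", "layer1", "layer2"]

# Persistent state: rank r keeps shard r of each layer, never a full layer.
persistent = {
    rank: {layer: f"{layer}/shard{rank}" for layer in LAYERS}
    for rank in range(WORLD_SIZE)
}

for layer in LAYERS:
    # During this layer's all-gather, every rank briefly sees all shards...
    gathered = [persistent[r][layer] for r in range(WORLD_SIZE)]
    print(f"{layer} unsharded on every rank: {gathered}")
    # ...runs this layer's forward, then drops `gathered` before moving on.
```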
There it is. I thought every GPU only had a single shard, corresponding to a piece of the model in its FSDP unit. There is no such thing as its FSDP unit. Every GPU has a piece (i.e. a shard) of every FSDP unit. The FSDP unit partitioning is a logical partitioning of the model. There is no assignment between a given GPU and a given FSDP unit.
As you initially said
The GPUs have a fraction of the whole model, but it is along a different axis.
Assuming we see a model as a stack of layers, we can think of the number of layers in the stack as the vertical axis of the model, while the layer width is the horizontal axis. A GPU has a full view of the vertical axis (i.e. a shard of each FSDP unit), but only a partial view along the horizontal axis. After running all-gather, the same GPU temporarily gains the full view along the horizontal axis for a single FSDP unit (i.e. a full layer of the model), while it still has only a shard of each of the other FSDP units.
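A back-of-the-envelope check on that picture, with made-up sizes (the exact peak depends on FSDP’s prefetching and on whether the local shard is kept during the gather, so this is only approximate):

```python
# Illustrative memory arithmetic for the two axes, all sizes invented.
n_layers, layer_params, world_size = 32, 1_000_000, 8

shard = layer_params // world_size            # one rank's slice of one layer
persistent = n_layers * shard                 # a shard of EVERY layer, always
peak = (n_layers - 1) * shard + layer_params  # 31 sharded layers + 1 full one

print(persistent)  # 4000000 params resident per GPU
print(peak)        # 4875000 params at the per-layer all-gather peak
```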
Many thanks for your patience.
This sounds right to me!