How does the FSDP algorithm work?

Hi,
I’m trying to understand FSDP but can’t find any good resources. I’ve broken it down into steps; can someone tell me if I’m right?
So let’s say that I have 4 GPUs and a model that can’t fit into one GPU, so it gets split into 2 shards.

  1. load weights from shard 1 into each GPU
  2. split a batch of data into 4 parts
  3. do the forward pass with part 1 on GPU 1, part 2 on GPU 2, etc.
  4. save intermediate outputs for backprop locally. They are different on each GPU because of different data
  5. offload weights from shard 1
  6. load weights from shard 2 into each GPU
  7. do the forward pass with the outputs from step 4 and save intermediate outputs. Again they are different on each GPU because of different inputs. The outputs from step 4 are still saved locally
  8. calculate the loss
  9. load the optimizer state for the weights from shard 2
  10. calculate gradients for the weights from shard 2. They are different on each GPU
  11. synchronize gradients so they are the same on all GPUs
  12. update weights on each GPU locally. After the update each GPU has the same weights
  13. offload the optimizer state and weights from shard 2
  14. load weights from shard 1 and the optimizer state for them
  15. calculate gradients for the weights from shard 1. They are different on each GPU
  16. synchronize gradients so they are the same on all GPUs
  17. update weights on each GPU locally. After the update each GPU has the same weights
  18. offload the optimizer state
  19. load another batch of data and split it into 4 parts
  20. repeat steps 3-19 until the stopping criterion is met

Am I right? Would FSDP work like that in the situation I described?

This tutorial and these 10 videos might give you the details about its implementation.

If you still have specific FSDP questions after the resources linked, I can try to help.


I do.
My main question is about the general idea of how it works. So let’s say that I have 8 GPUs and a model that can’t fit fully into any of them, but each can fit half of it. Will FSDP create 2 shards, load shard 1 on all GPUs, do a forward pass, offload shard 1, and do the same with shard 2, or will it load shard 1 on 4 GPUs and shard 2 on the other 4 GPUs?

At a high level, FSDP works as follows:

In constructor

  • Shard model parameters and each rank only keeps its own shard

In forward path

  • Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit

What does it mean to collect all shards from all ranks to recover the full parameter in this FSDP unit?

What is an FSDP unit?

  • Run forward computation
  • Discard parameter shards it has just collected

In backward path

  • Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit
  • Run backward computation
  • Run reduce_scatter to sync gradients
  • Discard parameters.

This looks good overall. A few notes:

  • For the backward pass, FSDP discards parameters before reduce-scattering gradients to free that memory earlier.

What does it mean to collect all shards from all ranks to recover the full parameter in this FSDP unit?

This means that every rank calls all-gather with its local shard as input. After the all-gather, every rank has the full unsharded parameters for the FSDP unit.
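In collective terms, that all-gather just concatenates the per-rank shards back into the flat parameter. A minimal sketch of the semantics (not FSDP’s actual internals; `unshard` is a hypothetical helper, and a process group is assumed to be already initialized):

```python
import torch
import torch.distributed as dist

def unshard(shard: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: each rank passes in its 1/world_size shard of a
    flat parameter; after the all-gather, every rank holds the full tensor."""
    world_size = dist.get_world_size()
    full_flat_param = torch.empty(world_size * shard.numel(),
                                  dtype=shard.dtype, device=shard.device)
    # Concatenate the shards from all ranks, in rank order, into full_flat_param.
    dist.all_gather_into_tensor(full_flat_param, shard)
    return full_flat_param
```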

What is an FSDP unit?

At a high level, an FSDP unit is a layer or collection of layers whose parameters are communicated together. Practically, an FSDP unit refers to one FullyShardedDataParallel instance. In the current design, one FullyShardedDataParallel instance wraps one nn.Module and constructs a single FlatParameter out of the nn.Module’s parameters (that are not already assigned to a nested FullyShardedDataParallel instance).
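For example, here is a minimal sketch of how nested wrapping creates multiple units (the three-block model is illustrative, and torch.distributed is assumed to be already initialized, e.g. via torchrun):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Hypothetical three-block model.
block0 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
block1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
head = nn.Linear(1024, 10)

# Each nested FSDP instance is its own "FSDP unit" with its own FlatParameter.
# The outer (root) FSDP instance flattens whatever parameters are left over
# (here, just `head`).
model = FSDP(nn.Sequential(FSDP(block0), FSDP(block1), head))
```

In practice you would usually let an auto_wrap_policy decide where the nested instances go instead of wrapping by hand.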

So let’s say that I have 8 gpus and model that can’t fit fully into any of them each can fit half of it. Will fsdp create 2 shards, load shard 1 on all gpus do a forward pass offload shard 1 and do the same with shard 2 or will it load shard 1 on 4 gpus and shard 2 on the other 4 gpus?

There are two dimensions to consider: (1) how many “units” you partition your model into and (2) how many ranks/workers you shard each “unit” across.

  • For (2), the default ShardingStrategy.FULL_SHARD and the ShardingStrategy.SHARD_GRAD_OP strategies shard across all ranks/workers. In your example, this means sharding over the 8 GPUs.
  • For (1), this depends on how you apply FullyShardedDataParallel to your model. If you only wrap once at the top-level, then there is only one “unit”. However, generally you want more than one “unit” to achieve communication-computation overlap and decrease peak memory. If you only have one “unit”, then the entire model is materialized for forward or for backward, so the entire model contributes to peak memory.

Putting this together, suppose you have a model that you partition into 3 “FSDP units”, denoted unit_0, unit_1, and unit_2, and you have 8 GPUs. The mental model can be like:
For i in {0, 1, 2}:

  • unit_i all-gather parameters → all 8 ranks have all parameters of unit_i
  • unit_i forward
  • unit_i free parameters → each rank only has its 1/8 of the parameters of unit_i

For i in {2, 1, 0}:

  • unit_i all-gather parameters → all 8 ranks have all parameters of unit_i
  • unit_i backward → each rank has the full gradient for its local batch (representing a partial result)
  • unit_i free parameters → each rank only has its 1/8 of the parameters of unit_i
  • unit_i reduce-scatter gradients → each rank has its 1/8 of the gradient for the global batch

In practice, due to the nested structure of nn.Modules, you may not see such a simple ordering of the pre/post-forward/backward logic, but the above gives the general mental model.
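A rough Python-flavored sketch of that per-iteration flow (pure pseudocode, not FSDP’s real implementation; all_gather_unit, free_unit, reduce_scatter_grads, and compute_loss are hypothetical helpers):

```python
def train_step(units, local_batch):
    # Forward: each unit is unsharded, used, then freed, one at a time.
    activations = local_batch
    for unit in units:                      # unit_0, unit_1, unit_2
        all_gather_unit(unit)               # every rank now holds unit's full parameters
        activations = unit.forward(activations)
        free_unit(unit)                     # back to the local 1/world_size shard

    loss = compute_loss(activations)

    # Backward: reverse order; gradients end up sharded via reduce-scatter.
    for unit in reversed(units):            # unit_2, unit_1, unit_0
        all_gather_unit(unit)               # unshard parameters again for backward
        unit.backward()                     # full gradient for the local batch
        free_unit(unit)                     # free unsharded parameters first...
        reduce_scatter_grads(unit)          # ...then keep 1/world_size of the global-batch gradient
    return loss
```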

Would FSDP be ideal in the circumstance where I have a small parameter size but large calculations within the layers? Meaning that while I can fit all of my layers into memory, trying to store multiple layers’ gradients causes OOM? Essentially what I want to do is have one layer processing on the GPU at a time and then remove it once it has been processed.

I have tried activation checkpointing, but it leads to the model’s loss not improving.

Thanks
Miles

So if I understand correctly:

Putting this together, suppose you have a model that you partition into 3 “FSDP units”, denoted unit_0, unit_1, and unit_2, and you have 8 GPUs. The mental model can be like:
For i in {0, 1, 2}:

  • unit_i all-gather parameters → all 8 ranks have all parameters of unit_i
  • unit_i forward
  • unit_i free parameters → each rank only has its 1/8 of the parameters of unit_i

For i in {2, 1, 0}:

  • unit_i all-gather parameters → all 8 ranks have all parameters of unit_i
  • unit_i backward → each rank has the full gradient for its local batch (representing a partial result)
  • unit_i free parameters → each rank only has its 1/8 of the parameters of unit_i
  • unit_i reduce-scatter gradients → each rank has its 1/8 of the gradient for the global batch

In practice, due to the nested structure of nn.Modules, you may not see such a simple ordering of the pre/post-forward/backward logic, but the above gives the general mental model.

Each GPU is not processing different parts of the model, but is processing a shard of the larger model. E.g., if I split my 200-layer model into 4 equal shards of 50 layers, GPUs 1 and 2 would both perform the forward for layers 1-50, then 51-100, etc., and then in the backward pass start with layers 151-200.

Therefore, any multi-GPU FSDP training should also work fine as single-GPU FSDP training, since it is the same thing but without the need to gather/scatter parameters between GPUs?

FSDP would not be the right solution for that case since it sounds like you have large activations. (Gradient memory is the same size as parameter memory, so small parameter size means small gradient size.)

Activation checkpointing or some kind of tensor parallelism both can help, where activation checkpointing may be easier to integrate. If your loss is not improving, then that is most likely a bug somewhere either in model code or in the integration of activation checkpointing. It should not change the numerics.
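If it helps to sanity-check the integration, here is a minimal sketch of per-block activation checkpointing with torch.utils.checkpoint.checkpoint (the Block and Model classes are illustrative; use_reentrant=False is the variant generally recommended in recent PyTorch versions):

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Illustrative compute-heavy block."""
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

class Model(nn.Module):
    def __init__(self, num_blocks: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(Block() for _ in range(num_blocks))

    def forward(self, x):
        for block in self.blocks:
            # Do not store this block's intermediate activations; recompute
            # them during the backward pass instead.
            x = checkpoint(block, x, use_reentrant=False)
        return x
```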

Yes, for the single GPU case, FSDP devolves to be the same as local training. There is no sharding, and there are no collectives (all-gather / reduce-scatter). As such, there is no reason to use FSDP then.

I think I might be misunderstanding. Does FSDP not split a large model into smaller chunks? Meaning that we can train large models one chunk/shard at a time provided that this chunk/shard fits into memory?

In short, the ‘smaller chunks’ are only relevant if there is sharding involved. FSDP’s semantics ensures that each chunk is restored to look like an equivalent local version (without sharding) before computation. However, if the world size (i.e. number of GPUs) is 1, then there is no sharding.

FSDP is not doing anything special to permit you to only partially train some layers/modules of a model.

Oh okay. Much appreciated for clearing up my misconceptions about FSDP. So what does the CPUOffload actually do?

CPUOffload means that the parameters are offloaded to CPU when not used for computation.

One way to think about it is that under FSDP, each parameter has two states: sharded and unsharded. The invariant is that the parameter must be unsharded for computation, but outside of computation, FSDP can return it to the sharded state (for memory savings).

  • In the vanilla form, the unsharded state is the original float32 parameter on GPU, and the sharded state is the worker’s 1 / world_size fraction of the float32 parameter on GPU.
  • If we enable CPU offloading, then the sharded state changes to being instead on the CPU.
  • If we enable parameter mixed precision, then the unsharded state changes to being in the low precision (e.g. bfloat16).
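Putting those options together, here is a minimal configuration sketch (assumes torch.distributed is already initialized, e.g. via torchrun; the model and the dtype are illustrative):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))

fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,              # shard parameters, gradients, and optimizer state across all ranks
    cpu_offload=CPUOffload(offload_params=True),                # keep the sharded parameters on CPU between uses
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), # unsharded compute parameters in bf16
)
```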

Where are the activations (the results from the forward pass) of the FSDP unit stored? According to my understanding, each GPU receives a different piece of data, so when the parameters are unsharded for the forward pass computation, the results of the forward pass remain on that GPU, and each GPU would have different results since each GPU receives different data?

Since FSDP is a form of data parallelism, each rank has its own activations for its local batch. Your understanding is correct!