Small clarification about FSDP docs

Hello!

In the post below, Figure 1 says that for each layer N, all parameters are gathered (all-gather), the forward pass runs, and then all the full weights are freed. However, all weights except layer N’s weights are freed, right? Otherwise, as I understand it, layer N’s weights couldn’t be all-gathered in the backward pass.

I wonder whether “besides layer N’s weights” is implicit in the post. :smiley:

Thanks!
Lucas

Could you link the figure?

I may be misunderstanding what you are saying, but why do you think that if layer N’s parameters are freed, then they will not be all-gathered in the backward pass?

The backward pass has an additional all-gather for the parameters that is separate from the one in the forward pass. It would not make sense to say that all parameters besides layer N’s parameters are freed because the parameters besides layer N’s parameters have already been freed / are not in memory – that is how FSDP is able to save memory.
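To make that timeline concrete, here is a toy, pure-Python simulation of one training step. It is not PyTorch’s implementation. It assumes a hypothetical 2-rank world with 3 FSDP-wrapped layers, and it ignores optimizations such as prefetching the next layer’s all-gather, so its peak of one fully materialized layer is idealized. What it shows is that the backward pass issues its own all-gather per layer instead of reusing anything kept from the forward pass.

```python
from typing import Dict, List, Tuple

WORLD_SIZE = 2  # hypothetical number of ranks
SHARD_LEN = 2   # hypothetical per-rank shard length (elements)


def all_gather(shards: List[List[float]]) -> List[float]:
    """Concatenate every rank's shard in rank order, as an all-gather does."""
    full: List[float] = []
    for shard in shards:
        full.extend(shard)
    return full


def run_step(num_layers: int) -> Tuple[List[str], int]:
    """Simulate one step; return the event log and the peak number of full layers held."""
    # Each rank permanently owns one shard of every layer's parameters.
    shards = {
        layer: [[float(10 * rank + i) for i in range(SHARD_LEN)] for rank in range(WORLD_SIZE)]
        for layer in range(num_layers)
    }
    materialized: Dict[int, List[float]] = {}  # layer -> gathered full parameters
    log: List[str] = []
    peak = 0

    def gather(layer: int, phase: str) -> None:
        nonlocal peak
        materialized[layer] = all_gather(shards[layer])
        peak = max(peak, len(materialized))
        log.append(f"{phase} all-gather L{layer}")

    def free(layer: int, phase: str) -> None:
        del materialized[layer]  # drop the full copy; the shards themselves stay
        log.append(f"{phase} free L{layer}")

    # Forward: first to last layer. The layer's forward compute would run between gather and free.
    for layer in range(num_layers):
        gather(layer, "fwd")
        free(layer, "fwd")
    # Backward: last to first layer, with a fresh all-gather for each layer.
    # Backward compute (and gradient reduce-scatter) would run around the free.
    for layer in reversed(range(num_layers)):
        gather(layer, "bwd")
        free(layer, "bwd")
    return log, peak


log, peak = run_step(num_layers=3)
print("\n".join(log))
print("peak full layers held:", peak)
```

The log shows six forward events followed by six backward events, and no layer’s full parameters outlive its own forward or backward computation.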

Hello @agu

Sorry, I posted without the link. This is the figure: https://pytorch.org/assets/images/fsdp_workflow.png
From: “Introducing PyTorch Fully Sharded Data Parallel (FSDP) API” (PyTorch blog)

I’m trying to build an overall understanding of how it works by reasoning about that figure and this one: https://engineering.fb.com/wp-content/uploads/2021/07/FSDP-graph-2a.png

I may be misunderstanding what you are saying, but why do you think that if layer N’s parameters are freed, then they will not be all-gathered in the backward pass?

I’m still getting into the details of how PyTorch uses collective communication, but for an all-gather to happen (in Open MPI, for instance), each worker (here, each FSDP instance) must have a valid (i.e., not freed) send buffer, right? Am I overlooking or misunderstanding something? (I probably am :smiley:)

Thanks!
Lucas

I think I understand the confusion.

Each worker (rank) keeps its local shard of each layer’s parameters at all times so that it can contribute that shard to the all-gather. When the diagram says that the model parameters are freed after the forward pass, everything except the rank’s local shard is freed.
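A minimal sketch of that answer, as a toy pure-Python model rather than PyTorch code: each hypothetical `RankState` holds a permanent `local_shard` (the all-gather send buffer) and a temporary `full_param`. "Freeing" only drops `full_param`, so a second all-gather in the backward pass still has valid send buffers. The zero-padding so that every rank gets an equal-sized shard loosely mirrors how FSDP flattens a unit’s parameters. The helper names and the 2-rank setup are assumptions for illustration.

```python
from typing import List, Optional


def shard_flat_param(full: List[float], world_size: int) -> List[List[float]]:
    """Split a flat parameter into equal per-rank shards, zero-padding the tail."""
    shard_len = -(-len(full) // world_size)  # ceiling division
    padded = full + [0.0] * (shard_len * world_size - len(full))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


class RankState:
    """What one rank holds for one FSDP unit in this hypothetical model."""

    def __init__(self, local_shard: List[float]) -> None:
        self.local_shard = local_shard                 # kept for the whole run: the send buffer
        self.full_param: Optional[List[float]] = None  # exists only between all-gather and free

    def free_full_param(self) -> None:
        self.full_param = None                         # "freed" = drop the gathered copy only


def all_gather(ranks: List[RankState], numel: int) -> None:
    """Every rank sends its local shard and receives the concatenation of all shards."""
    gathered: List[float] = []
    for rank in ranks:
        gathered.extend(rank.local_shard)
    for rank in ranks:
        rank.full_param = gathered[:numel]             # new list per rank, padding stripped


full = [1.0, 2.0, 3.0, 4.0, 5.0]
ranks = [RankState(s) for s in shard_flat_param(full, world_size=2)]

all_gather(ranks, len(full))   # forward all-gather
for rank in ranks:
    rank.free_full_param()     # free after forward; local shards survive
all_gather(ranks, len(full))   # backward all-gather still has valid send buffers
print(ranks[0].full_param, ranks[1].local_shard)
```

If `free_full_param` also cleared `local_shard`, the backward all-gather would have nothing to send, which is exactly the situation the question was worried about.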