Communication behavior of FSDP

Hello everyone,

I’m currently profiling an FSDP-wrapped model. I use both the full_shard and the hybrid_shard sharding strategy.

My model has 12 FSDP units, so I see 12 FSDP.forward calls (blocks), and likewise 12 blocks in the backward pass. Each forward block always consists of the same three functions:

  • pre_forward (launches all_gather)
  • compute
  • post_forward
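To make sure I understand what each block does, here is a tiny pure-Python simulation of one unit with full_shard. It is only a sketch: the class name SimulatedFSDPUnit and its methods are hypothetical, nothing is actually communicated, and "all_gather" is modeled as concatenating every rank's shard of a flattened parameter list.

```python
class SimulatedFSDPUnit:
    """Hypothetical stand-in for one FSDP unit; no real communication happens."""

    def __init__(self, flat_param, world_size):
        # Each rank keeps only a 1/world_size slice of the flattened parameters.
        n = len(flat_param)
        assert n % world_size == 0, "sketch assumes evenly divisible shards"
        size = n // world_size
        self.shards = [flat_param[r * size:(r + 1) * size] for r in range(world_size)]
        self.full_param = None  # unsharded parameters, only materialized around compute
        self.log = []

    def pre_forward(self):
        # Simulated all_gather: concatenate every rank's shard into the full flat parameter.
        self.log.append("all_gather")
        self.full_param = [v for shard in self.shards for v in shard]

    def compute(self, x):
        # Toy compute: dot product of the input with the unsharded parameters.
        return sum(a * b for a, b in zip(x, self.full_param))

    def post_forward(self):
        # With full_shard, free the unsharded parameters again and keep only the shards.
        self.full_param = None

    def forward(self, x):
        self.pre_forward()
        out = self.compute(x)
        self.post_forward()
        return out


unit = SimulatedFSDPUnit([1, 2, 3, 4], world_size=2)
print(unit.forward([1, 1, 1, 1]), unit.log)
```

In this model, one unit issues exactly one all_gather per forward call, which is why I expected the count to match the number of units.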

I would expect 12 all_gather calls, one triggered by the pre_forward of each of the 12 forward blocks. However, at the very beginning there is one additional standalone pre_forward that launches an extra all_gather. So overall I see 13 all_gather calls for 12 units.

In the backward pass I find, as expected, 12 all_gather calls, but 13 reduce_scatter calls: the extra one is triggered by a separate post_backward after the last backward block. With the hybrid_shard strategy I see, as expected, additional all_reduce calls, but again there are 13 of them.
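To check whether the counts are at least self-consistent, here is a small pure-Python simulation built on an assumption I have not verified: that the extra unit is the root FSDP instance wrapping the whole model (owning parameters outside the 12 blocks, e.g. embeddings), and that this root is not resharded after forward. The helper names (build_units, trace_forward, trace_backward) are hypothetical and model only the order and type of collectives, not real FSDP internals.

```python
from collections import Counter


def build_units(num_blocks):
    # Assumption: one root unit owns every parameter not inside a wrapped block
    # and is NOT resharded after forward, so it needs no re-gather in backward.
    root = {"name": "root", "reshard_after_forward": False}
    blocks = [{"name": f"block{i}", "reshard_after_forward": True} for i in range(num_blocks)]
    return [root] + blocks


def trace_forward(units):
    # The root's pre_forward runs first, then one pre_forward per wrapped block;
    # each one all-gathers that unit's parameters.
    return [("all_gather", u["name"]) for u in units]


def trace_backward(units, hybrid_shard=False):
    events = []
    # Wrapped blocks run backward in reverse order; the root's post_backward is
    # modeled as firing last, after the final block.
    order = list(reversed(units[1:])) + [units[0]]
    for u in order:
        # pre_backward re-gathers only parameters that were freed after forward.
        if u["reshard_after_forward"]:
            events.append(("all_gather", u["name"]))
        # post_backward reduce-scatters gradients within the sharding group...
        events.append(("reduce_scatter", u["name"]))
        # ...and with hybrid sharding also all-reduces them across replicas.
        if hybrid_shard:
            events.append(("all_reduce", u["name"]))
    return events


def count(events):
    return Counter(op for op, _ in events)


units = build_units(12)
print(count(trace_forward(units)))
print(count(trace_backward(units, hybrid_shard=True)))
```

Under these assumptions the simulation gives 13 all_gather calls in forward, and 12 all_gather, 13 reduce_scatter and 13 all_reduce calls in backward, which matches what I see in the profile. I'd still like confirmation that this is really what FSDP does.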

My question is:

Does anyone know what the 13th all_gather in the forward pass and the 13th reduce_scatter and all_reduce in the backward pass are doing?
Why do these additional communication calls happen, and why is there no corresponding extra all_gather in the backward pass?

I would appreciate any help!