Hello everyone,
I’m currently profiling an FSDP-wrapped model. I use both the full_shard and the hybrid_shard policy.
My model has 12 units, so I have 12 FSDP.forward calls (blocks), and likewise for the backward pass. Each forward block always consists of the same three functions:
- pre_forward (launches all_gather)
- compute
- post_forward
I would expect 12 all_gather calls, triggered by the 12 pre_forward functions of the 12 forward blocks. However, at the very beginning there is one additional, standalone pre_forward that launches an additional all_gather. So overall I have 13 all_gather calls for 12 units.
In the backward pass I find, as expected, 12 all_gather calls, but 13 reduce_scatter calls. The additional one is triggered by a separate post_backward after the last backward block. If I use the hybrid_shard policy I get, as expected, some additional all_reduce calls, but again there are 13 of them.
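For anyone who wants to reproduce these counts, here is a minimal sketch of how the collectives can be tallied once the event names have been exported from the profiler trace. The helper `count_collectives`, the substring list, and the example event names are all hypothetical; the actual op names in a trace depend on the PyTorch and NCCL versions.

```python
from collections import Counter

# Substrings to look for. These are assumptions; the real event names
# in a trace vary between PyTorch / NCCL versions.
COLLECTIVES = ("all_gather", "reduce_scatter", "all_reduce")


def count_collectives(event_names):
    """Count how many event names mention each collective."""
    counts = Counter()
    for name in event_names:
        lowered = name.lower()
        for coll in COLLECTIVES:
            if coll in lowered:
                counts[coll] += 1
                break  # attribute each event to at most one collective
    return counts


# Hypothetical event names mimicking the counts described above:
# 13 all_gather in forward, 12 all_gather + 13 reduce_scatter in backward.
forward_events = ["nccl:_all_gather_base"] * 13 + ["aten::mm"] * 12
backward_events = (
    ["nccl:_all_gather_base"] * 12
    + ["nccl:_reduce_scatter_base"] * 13
    + ["aten::mm"] * 24
)

print("forward: ", dict(count_collectives(forward_events)))
print("backward:", dict(count_collectives(backward_events)))
```

Counting forward and backward separately makes the mismatch visible directly: 13 all_gather against 12 units in forward, and 12 all_gather against 13 reduce_scatter in backward.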
My question is:
Does anyone know what the 13th all_gather in the forward pass and the 13th reduce_scatter and all_reduce in the backward pass are doing?
Why do I have these additional communication calls, and why is there no additional all_gather in the backward pass?
Would appreciate any kind of help!