FSDP multi-node comms overhead

I am trying to use FSDP HYBRID_SHARD for multi-node training and I am seeing unexpectedly large comms overhead. I am currently trying to diagnose the issue and am hoping someone here can help me understand what needs to be changed.
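
For reference, the setup is roughly the following (a simplified sketch - the checkpoint name, auto-wrap policy, and mixed-precision settings are approximations of what we actually run, not the exact code):

```python
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import T5Block

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = T5ForConditionalGeneration.from_pretrained("t5-11b")  # placeholder for our actual checkpoint

# Wrap each T5 block as its own FSDP unit (simplified; the real policy has a few more knobs).
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={T5Block}
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # shard within a node, replicate across nodes
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,  # matches the bf16 all-reduces in the trace
    ),
    device_id=torch.cuda.current_device(),
)
```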

Not sure what the best way to share my PyTorch profiler trace is… here is a screenshot. You can see there is a long green block (stream 24) of NCCL AllReduce_Sum_bf16 ops that is the bottleneck - I believe these are the cross-node gradient all-reduces? So I guess I have a few questions:

  1. I thought that in FSDP the gradient all-reduce was executed as a reduce-scatter followed by an all-gather. I do see those in smaller blocks throughout the backward pass, which I assume are the intra-node gradient syncs. Why doesn't the cross-node sync use the same scheme?
  2. We are definitely not seeing the bandwidth utilization we expect. This is an 11B T5 model on a cluster that should have ~1 Tbps of inter-node bandwidth (5x 200 Gbps NICs). Using nccl-tests we measured ~600 Gbps of usable bandwidth - plenty to make this workload GPU-limited - but if I convert the time these all-reduces take into a Gbps number (rough math sketched below), it's only about 170 Gbps. Maybe we're only using one NIC for some reason? Why would that be the case in our torch code but not when running the nccl-tests benchmarks?
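
For transparency, the ~170 Gbps figure comes from a back-of-the-envelope conversion along these lines (following the nccl-tests bus-bandwidth convention; the GPUs-per-node count in the comments is an illustrative placeholder, not necessarily our exact topology):

```python
def allreduce_busbw_gbps(message_bytes: float, elapsed_s: float, n_ranks: int) -> float:
    """Convert an all-reduce duration into nccl-tests-style "bus bandwidth" in Gbps.

    A ring all-reduce moves 2*(n-1)/n of the message over each rank's links,
    which is the same convention nccl-tests uses for its busbw column.
    """
    algo_bytes_per_s = message_bytes / elapsed_s
    bus_bytes_per_s = algo_bytes_per_s * 2 * (n_ranks - 1) / n_ranks
    return bus_bytes_per_s * 8 / 1e9  # bytes/s -> Gbps


# With HYBRID_SHARD, each rank all-reduces only its intra-node shard across nodes,
# so the message size is roughly (total gradient bytes) / (GPUs per node).
params = 11e9                  # 11B T5
grad_bytes = params * 2        # bf16 gradients
shard_bytes = grad_bytes / 8   # assuming 8 GPUs per node (placeholder)

# e.g. allreduce_busbw_gbps(shard_bytes,
#                           elapsed_s=<AllReduce block duration from the trace>,
#                           n_ranks=<number of nodes in the replica group>)
```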

Anyway, I know there are a lot of environment variables and other configs that could be causing the bottleneck, but if anyone here has ideas, that would be amazing - looking forward to being able to fully utilize our resources!
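
For what it's worth, here is what I'm planning to check next on the NCCL side (these are standard NCCL env vars; the HCA/interface names in the commented-out lines are placeholders for whatever `ibv_devices` reports on our nodes):

```python
import os

# These must be set before the first NCCL communicator is created
# (i.e. before init_process_group / the first collective), otherwise they have no effect.
os.environ["NCCL_DEBUG"] = "INFO"             # logs which transports and interfaces NCCL selects
os.environ["NCCL_DEBUG_SUBSYS"] = "INIT,NET"  # focus the logs on init + network setup

# If the logs show only a single HCA in use, pin the full set explicitly
# (device/interface names below are placeholders, not our actual ones):
# os.environ["NCCL_IB_HCA"] = "mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4"
# os.environ["NCCL_SOCKET_IFNAME"] = "eth0"

import torch.distributed as dist

dist.init_process_group(backend="nccl")
```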