Cases where Fully Sharded Data Parallel is not mathematically equivalent to local training

In section 7.2.1 of the PyTorch FSDP paper, the authors say:

FSDP cannot ensure that it always achieves the same mathematical equivalence as local training, especially with respect to the optimizer computation. This stems from the fact that the optimizer step operates on the sharded parameters, whose data layout is a function of FSDP’s FlatParameter sharding algorithm that does not respect individual parameter boundaries. As a result, any optimizer computation that depends on an original parameter’s unsharded value (e.g. vector norm), its tensor structure (e.g. approximate second-order optimizers), or require global states over all parameters will become invalid.

I have two questions:

  1. How does the FlatParameter sharding algorithm not respect individual parameter boundaries?
  2. What is the implication of the last sentence? What kinds of models cannot be trained using FSDP?

cc @agu @weifengpy for FSDP questions

Suppose we shard param1 = [1.0] and param2 = [2.0, 3.0, 4.0] over two GPUs. FSDP flattens and concatenates them into a single FlatParameter = [1.0, 2.0, 3.0, 4.0] and then splits that flat buffer evenly across ranks:

  • Rank 0 holds the shard [1.0, 2.0]: it has all of param1 but only part of param2.
  • Rank 1 holds the shard [3.0, 4.0]: it has the rest of param2 and nothing of param1.
  • In general, a rank may hold only a fragment of a parameter (or none of it at all), which is what "does not respect individual parameter boundaries" means; see the sketch after this list.
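
Here is a minimal sketch (plain PyTorch tensors, no actual FSDP or process group) of this flatten-then-chunk sharding; the variable names are illustrative, not FSDP internals:

```python
import torch

param1 = torch.tensor([1.0])
param2 = torch.tensor([2.0, 3.0, 4.0])

# FSDP concatenates the flattened parameters into one 1-D FlatParameter ...
flat_param = torch.cat([param1.flatten(), param2.flatten()])  # [1., 2., 3., 4.]

# ... and splits it evenly across ranks, ignoring where param1 ends and param2 begins.
world_size = 2
shards = flat_param.chunk(world_size)

for rank, shard in enumerate(shards):
    print(f"rank {rank}: {shard.tolist()}")
# rank 0: [1.0, 2.0]  -> all of param1 plus the first element of param2
# rank 1: [3.0, 4.0]  -> the remaining elements of param2, none of param1
```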

This suggests that pointwise optimizers, whose update for each element depends only on that element's own gradient and state, can be used out of the box, including SGD, Adam, and AdamW (see the sketch below).
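
As a quick sanity check (a toy sketch, not FSDP itself): because Adam's update is elementwise, running it on a rank's shard produces exactly the same values as running it on the full parameter and then slicing out that shard:

```python
import torch

full = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0, 4.0]))
shard = torch.nn.Parameter(full.detach()[:2].clone())  # pretend this is rank 0's shard

grad = torch.tensor([0.1, 0.2, 0.3, 0.4])
full.grad = grad.clone()
shard.grad = grad[:2].clone()

torch.optim.Adam([full], lr=0.1).step()
torch.optim.Adam([shard], lr=0.1).step()

# Pointwise update => the shard matches the corresponding slice of the full parameter.
assert torch.allclose(full.detach()[:2], shard.detach())
```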
If the optimizer step involves a non-pointwise op, say a vector norm for gradient clipping, each rank can only compute a partial result over its shard, so the per-rank partial norms must be combined with an NCCL all-reduce across ranks before the global value is usable, as in the sketch below.
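
A minimal sketch of that global-norm computation, assuming a torch.distributed process group is already initialized and `local_grads` holds only this rank's gradient shards (the function name is made up for illustration; FSDP also exposes a `clip_grad_norm_` method on the wrapped module that performs this cross-rank reduction for you):

```python
import torch
import torch.distributed as dist

def clip_grad_norm_sharded(local_grads, max_norm: float, eps: float = 1e-6):
    # Per-rank partial sum of squared gradient elements (only this rank's shards).
    local_sq = torch.stack([(g.detach() ** 2).sum() for g in local_grads]).sum()
    # All-reduce (NCCL on GPU) so every rank sees the same global sum of squares.
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)
    total_norm = local_sq.sqrt()
    # Scale every local shard by the same global factor, mirroring
    # torch.nn.utils.clip_grad_norm_.
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for g in local_grads:
            g.detach().mul_(clip_coef)
    return total_norm
```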