Sorry, I missed the code. If you do not use nested FSDP wrapping, then you will incur a 1 + 1 / W memory overhead from the parameters across W workers (since FSDP always keeps the persistent 1 / W shard in addition to the unsharded parameters), which could be why you see higher GPU memory usage with FSDP.
The FSDP wrapping effectively determines the max parameter size that is unsharded (all-gathered) at once, so if you only wrap at the top-level, then all parameters contribute to that max size. Therefore, nested wrapping is necessary to decrease the max memory contribution from parameters.
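The effect of the wrapping granularity on the parameter-memory peak can be sketched with back-of-envelope arithmetic (plain Python, not actual FSDP code; the unit sizes below are made-up numbers for illustration):

```python
# Per-rank peak parameter memory under FSDP, assuming the peak is the
# persistent 1/W shard plus the largest FSDP unit unsharded (all-gathered)
# at once. P = total parameter count, W = number of workers.

def peak_param_memory(total_params, world_size, max_unit_params):
    """Persistent shard (total/W) plus the largest unit unsharded at once."""
    shard = total_params / world_size
    return shard + max_unit_params

P, W = 1_000_000, 8

# Top-level-only wrapping: the single FSDP unit is the whole model,
# so the peak is P/W + P = (1 + 1/W) * P.
flat_peak = peak_param_memory(P, W, max_unit_params=P)

# Nested wrapping: suppose the largest wrapped submodule holds 100k
# parameters, so only that much is unsharded at any one time.
nested_peak = peak_param_memory(P, W, max_unit_params=100_000)

print(flat_peak, nested_peak)  # 1125000.0 225000.0
```

The flat case reproduces the (1 + 1/W) overhead mentioned above; nested wrapping shrinks the unsharded term from P down to the largest wrapped unit.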
For ZeroRedundancyOptimizer, the implementation is different. The optimizer states are greedily partitioned across ranks to minimize unevenness. It does not shard parameters or gradients.
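A minimal sketch of that greedy partitioning idea, assigning each parameter's optimizer state to the currently least-loaded rank (the function name, the largest-first ordering, and the tie-breaking are illustrative assumptions, not the actual PyTorch implementation):

```python
# Greedy size-based partitioning: place each parameter's optimizer state
# on the rank with the smallest running total, largest parameters first,
# to keep the per-rank totals as even as possible.

def greedy_partition(param_sizes, world_size):
    loads = [0] * world_size                      # running total per rank
    assignment = [[] for _ in range(world_size)]  # param indices per rank
    # Place big states first so small ones can fill in the gaps.
    for idx in sorted(range(len(param_sizes)), key=lambda i: -param_sizes[i]):
        rank = loads.index(min(loads))  # least-loaded rank so far
        assignment[rank].append(idx)
        loads[rank] += param_sizes[idx]
    return assignment, loads

sizes = [512, 300, 300, 128, 64, 64]
assignment, loads = greedy_partition(sizes, world_size=2)
print(loads)  # [704, 664] -- roughly balanced across the two ranks
```

Note that only the optimizer states are divided up this way; the parameters and gradients themselves stay fully replicated on every rank, which is exactly the gap FSDP fills.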
Thanks for your reply; it's really helpful. But I still have some questions. Is ZeroRedundancyOptimizer an implementation of ZeRO-1? If so, why not continue on to implement ZeRO-2 and ZeRO-3 on top of ZeroRedundancyOptimizer instead of introducing FSDP separately? Is there a reason for this?
One reason is that ZeroRedundancyOptimizer is a torch.optim.Optimizer, but in order to shard parameters/gradients, we need something that interfaces with nn.Module. Hence, FullyShardedDataParallel is an nn.Module (wrapper).
Hello, I still have a question that I don't understand. When FSDP does not use nested wrapping, its memory consumption is higher than ZeroRedundancyOptimizer's, even though FSDP subsumes ZeRO-1. Is this caused by the implementation differences between FSDP and ZeroRedundancyOptimizer? Where does FSDP's extra memory consumption come from?
It is an implementation detail of FSDP. See the description above:
If you do not use nested FSDP wrapping, then you will incur 1 + 1 / W memory overhead from the parameters for W workers (due to FSDP always maintaining the 1 / W shard), which could be why you see higher GPU memory usage with FSDP.