Difference between FullyShardedDataParallel and ZeroRedundancyOptimizer?

Both ZeroRedundancyOptimizer and FullyShardedDataParallel are PyTorch classes based on the algorithms from the “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models” paper.

From an API perspective, ZeroRedundancyOptimizer wraps a torch.optim.Optimizer to provide ZeRO-1 semantics (i.e. P_{os} from the paper). In contrast, FullyShardedDataParallel wraps a torch.nn.Module and, when used with sharding_strategy=ShardingStrategy.FULL_SHARD, provides ZeRO-3 semantics (i.e. P_{os+g+p} from the paper).
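
To make the API difference concrete, here is a minimal sketch of the two wrapping patterns, assuming a process group has already been initialized (e.g. under torchrun); the layer sizes and hyperparameters are placeholders:

```python
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes dist.init_process_group("nccl") has already run under torchrun.

# ZeRO-1: wrap the *optimizer*. ZeroRedundancyOptimizer only shards optimizer
# states, so it is typically combined with DDP for gradient synchronization.
ddp_model = DDP(torch.nn.Linear(1024, 1024).cuda())
zero1_optim = ZeroRedundancyOptimizer(
    ddp_model.parameters(),
    optimizer_class=torch.optim.Adam,
    lr=1e-3,
)

# ZeRO-3: wrap the *module*. FSDP shards parameters, gradients, and optimizer
# states; a plain optimizer is then constructed over the sharded parameters.
fsdp_model = FSDP(
    torch.nn.Linear(1024, 1024).cuda(),
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
```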

ZeRO-1 shards only the optimizer states, so it saves memory during the optimizer step. ZeRO-3 shards optimizer states, gradients, and parameters, so it saves memory throughout the forward pass, backward pass, and optimizer step.
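
As a rough worked example using the mixed-precision memory model from the paper (exact numbers depend on the optimizer and precision settings): with Adam, each of the Ψ parameters costs about 2 bytes for the fp16 weight, 2 bytes for the fp16 gradient, and K = 12 bytes of optimizer state (fp32 master weight, momentum, and variance), i.e. 16Ψ bytes total, or roughly 120 GB per GPU for a 7.5B-parameter model without sharding. With ZeRO-1 across N_d = 64 GPUs, only the 12Ψ of optimizer state is sharded, giving 4Ψ + 12Ψ/N_d ≈ 31.4 GB per GPU; with ZeRO-3, everything is sharded, giving 16Ψ/N_d ≈ 1.9 GB per GPU.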

I can provide more details if you have further questions.