Both ZeroRedundancyOptimizer and FullyShardedDataParallel are PyTorch classes based on the algorithms from the “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models” paper. From an API perspective, ZeroRedundancyOptimizer wraps a torch.optim.Optimizer to provide ZeRO-1 semantics (i.e., P_{os} from the paper). In contrast, FullyShardedDataParallel wraps a torch.nn.Module to provide ZeRO-3 semantics (i.e., P_{os+g+p} from the paper) when using sharding_strategy=ShardingStrategy.FULL_SHARD.
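For concreteness, here is a minimal ZeRO-1 sketch. It assumes the process group has already been initialized (e.g., by launching with torchrun and calling dist.init_process_group beforehand); the toy Linear model and the Adam hyperparameters are purely illustrative:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer

# Assumes torchrun launched this script and that dist.init_process_group("nccl")
# and torch.cuda.set_device(local_rank) have already been called.
model = DDP(torch.nn.Linear(1024, 1024).cuda())

# ZeRO-1: each rank stores only its shard of the Adam state; parameters and
# gradients stay fully replicated, exactly as in plain DDP.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.Adam,
    lr=1e-3,
)

loss = model(torch.randn(8, 1024).cuda()).sum()
loss.backward()
optimizer.step()  # updates the local shard, then syncs parameters across ranks
```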
ZeRO-1 shards only the optimizer states, so it saves memory only during the optimizer step. ZeRO-3 shards optimizer states, gradients, and parameters, so it saves memory throughout the forward, backward, and optimizer steps.
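A corresponding ZeRO-3 sketch with FullyShardedDataParallel, under the same setup assumptions (FULL_SHARD is already the default sharding strategy; it is passed explicitly here only for clarity):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Same torchrun/process-group setup as the ZeRO-1 sketch above.
# ZeRO-3: parameters are sharded across ranks and gathered on the fly for each
# forward/backward pass; gradients and optimizer states remain sharded as well.
model = FSDP(
    torch.nn.Linear(1024, 1024).cuda(),
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)

# The optimizer is constructed after wrapping, so it only ever sees this
# rank's shard of the parameters, and its state is sharded accordingly.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 1024).cuda()).sum()
loss.backward()
optimizer.step()
```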
I can provide more details if you have further questions.