Difference between FullyShardedDataParallel and ZeroRedundancyOptimizer?

Is my understanding of FullyShardedDataParallel and ZeroRedundancyOptimizer correct?

FullyShardedDataParallel shards the model's parameters, which saves GPU memory during the forward pass.
ZeroRedundancyOptimizer shards the optimizer's state, which saves GPU memory during the backward pass and the gradient update.

Both ZeroRedundancyOptimizer and FullyShardedDataParallel are PyTorch classes based on the algorithms from the “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models” paper.

From an API perspective, ZeroRedundancyOptimizer wraps a torch.optim.Optimizer to provide ZeRO-1 semantics (i.e. P_{os} from the paper). In contrast, FullyShardedDataParallel wraps a torch.nn.Module to provide ZeRO-3 semantics (i.e. P_{os+g+p} from the paper) when used with sharding_strategy=ShardingStrategy.FULL_SHARD.
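The difference in what gets wrapped can be sketched as follows. This is a minimal sketch, assuming PyTorch 2.x, a process group that is already initialized (for example in a script launched with torchrun), and Adam as the optimizer; build_zero1, build_zero3 and train_step are hypothetical helper names. The ZeRO-1 variant pairs ZeroRedundancyOptimizer with DistributedDataParallel, so the model itself is still fully replicated on every rank. The ZeRO-3 variant hands the module to FSDP and builds an ordinary optimizer over FSDP's sharded parameters, so each rank only allocates optimizer state for its own shard.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP


def build_zero1(model: nn.Module) -> tuple[nn.Module, torch.optim.Optimizer]:
    """Hypothetical helper: ZeRO-1 (P_os).

    DDP keeps a full copy of parameters and gradients on every rank;
    ZeroRedundancyOptimizer wraps Adam so each rank stores Adam state
    only for its partition of the parameters.
    """
    ddp_model = DDP(model)
    optimizer = ZeroRedundancyOptimizer(
        ddp_model.parameters(),
        optimizer_class=torch.optim.Adam,  # the wrapped torch.optim.Optimizer
        lr=1e-3,                           # forwarded to Adam as a default
    )
    return ddp_model, optimizer


def build_zero3(model: nn.Module) -> tuple[nn.Module, torch.optim.Optimizer]:
    """Hypothetical helper: ZeRO-3 (P_os+g+p).

    FSDP with FULL_SHARD shards parameters and gradients across ranks.
    A plain Adam built over the FSDP-wrapped (sharded) parameters therefore
    only creates state for the local shard. Assumes one CUDA device per rank.
    """
    fsdp_model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
    )
    optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
    return fsdp_model, optimizer


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer,
               batch: torch.Tensor, target: torch.Tensor) -> float:
    """One ordinary training step; identical for both wrapping styles."""
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(batch), target)
    loss.backward()
    optimizer.step()
    return loss.item()
```

Both pairs are used the same way in the training loop; only the construction differs. In this sketch the ZeRO-1 path also runs on CPU with the gloo backend, whereas the FSDP path assumes a CUDA device per rank.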

ZeRO-1 only shards optimizer states. ZeRO-3 shards optimizer states, gradients, and parameters. ZeRO-1 saves memory during the optimizer step, and ZeRO-3 saves memory throughout forward, backward, and optimizer steps.
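To make the difference concrete, the paper's per-device memory model for mixed-precision Adam can be evaluated directly. This sketch assumes the paper's accounting: 2 bytes per parameter for fp16 weights, 2 bytes for fp16 gradients, and K = 12 bytes of optimizer state (fp32 master weights, momentum, variance), with activations and temporary buffers ignored; zero_memory_gb is a hypothetical helper name. ZeRO-2 (P_{os+g}) is included for comparison.

```python
def zero_memory_gb(num_params: float, num_devices: int, k: int = 12) -> dict[str, float]:
    """Per-device model-state memory in GB under the ZeRO paper's accounting.

    num_params: total parameter count (Psi in the paper)
    num_devices: data-parallel degree (N_d in the paper)
    k: bytes of optimizer state per parameter (12 for mixed-precision Adam)
    """
    psi, n = num_params, num_devices
    params, grads, opt_state = 2 * psi, 2 * psi, k * psi  # bytes
    bytes_per_device = {
        "baseline": params + grads + opt_state,           # nothing sharded
        "ZeRO-1 P_os": params + grads + opt_state / n,    # optimizer state sharded
        "ZeRO-2 P_os+g": params + (grads + opt_state) / n,
        "ZeRO-3 P_os+g+p": (params + grads + opt_state) / n,  # everything sharded
    }
    return {stage: b / 1e9 for stage, b in bytes_per_device.items()}


for stage, gb in zero_memory_gb(7.5e9, 64).items():
    print(f"{stage:>16}: {gb:7.2f} GB")
```

With 7.5B parameters on 64 devices this reproduces the figures the paper reports for that configuration (roughly 120 GB, 31.4 GB, 16.6 GB and 1.9 GB). ZeRO-1 only shrinks the optimizer-state term, while ZeRO-3 divides every term by the number of devices.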

I can provide more details if you have further questions.
