Hi, I found that torch.distributed.optim.ZeroRedundancyOptimizer does not implement the real ZeRO-1 algorithm from DeepSpeed.
The training process with ZeroRedundancyOptimizer is as follows (correct me if I'm wrong):
- Compute the gradients of all layers on every device, then AllReduce all gradients.
- Shard the parameters by layer across the devices, and run the Adam update only for the shard owned by each device.
- Broadcast the updated parameters from each shard's owner to the other devices.
Thus, ZeroRedundancyOptimizer incurs an extra broadcast on top of the baseline (plain DDP with a non-sharded optimizer); the sketch below marks where each step happens.
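For reference, here is a minimal usage sketch. It assumes a torchrun launch and the gloo backend on CPU; the toy linear model and the synthetic batch are only for illustration, not part of the library.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer


def main():
    # Assumes launch via torchrun, which sets the rendezvous env vars.
    dist.init_process_group(backend="gloo")

    model = torch.nn.Linear(128, 10)
    ddp_model = DDP(model)  # backward() hooks AllReduce the full gradients

    # Each rank keeps Adam state only for its own shard of the parameters;
    # step() runs Adam on that shard and then broadcasts the updated
    # parameters to the other ranks.
    optimizer = ZeroRedundancyOptimizer(
        ddp_model.parameters(),
        optimizer_class=torch.optim.Adam,
        lr=1e-3,
    )

    inputs = torch.randn(32, 128)
    targets = torch.randint(0, 10, (32,))
    loss = F.cross_entropy(ddp_model(inputs), targets)

    loss.backward()       # (1) AllReduce of all gradients (done by DDP)
    optimizer.step()      # (2) local Adam on this rank's shard
                          # (3) broadcast of updated parameters
    optimizer.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```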
But as far as I know, the steps of the original ZeRO-1 from DeepSpeed are:
- Compute the gradients of all layers on every device, then ReduceScatter all gradients, so each device ends up with only the gradient shard for the parameters it owns.
- Run the optimizer on the owned parameter shard using that gradient shard.
- AllGather the updated parameters so every device has the full parameter set again.
Compared to the baseline, this adds no extra communication, since a ReduceScatter followed by an AllGather moves the same total volume as one AllReduce. A minimal sketch of this pattern follows.
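Here is a simplified illustration of that ReduceScatter / local step / AllGather pattern for a single flattened parameter tensor, using plain torch.distributed collectives. This is only a sketch of the idea, not DeepSpeed's actual implementation: the class name, the flat-tensor layout, and the divisibility assumption are mine, and it assumes an initialized process group whose backend supports reduce_scatter_tensor and all_gather_into_tensor (e.g. NCCL).

```python
import torch
import torch.distributed as dist


class Zero1StyleOptimizer:
    """Sketch of a ZeRO-1 style step over one flat parameter tensor.

    Assumes flat_param is a contiguous, requires_grad=False buffer whose
    numel() is divisible by the world size, and that gradients arrive as a
    matching flat tensor.
    """

    def __init__(self, flat_param: torch.Tensor, lr: float = 1e-3):
        self.flat_param = flat_param
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        shard_size = flat_param.numel() // world_size
        # View of the shard this rank owns; Adam state exists only for it.
        self.param_shard = flat_param.narrow(0, rank * shard_size, shard_size)
        self.inner = torch.optim.Adam([self.param_shard], lr=lr)

    def step(self, flat_grad: torch.Tensor) -> None:
        world_size = dist.get_world_size()

        # (1) ReduceScatter: each rank receives the summed gradient for its
        #     own shard only (no full-gradient AllReduce).
        grad_shard = torch.empty_like(self.param_shard)
        dist.reduce_scatter_tensor(grad_shard, flat_grad, op=dist.ReduceOp.SUM)

        # (2) Local Adam step on the owned shard, with averaged gradients.
        self.param_shard.grad = grad_shard / world_size
        self.inner.step()
        self.inner.zero_grad()

        # (3) AllGather the updated shards so every rank holds the full,
        #     updated flat parameter again.
        dist.all_gather_into_tensor(self.flat_param, self.param_shard.clone())
```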
So why does PyTorch implement the former design instead of the latter? Is there a roadmap for migrating to the latter implementation?