Will PyTorch support real ZeRO-1 natively?

Hi, I found that torch.distributed.optim.ZeroRedundancyOptimizer does not implement the real ZeRO-1 of DeepSpeed.

The training process with torch.distributed.optim.ZeroRedundancyOptimizer is as follows (correct me if I'm wrong):

  1. Compute the gradients of all layers on every device, then AllReduce all of the gradients.
  2. Shard the full set of parameters (e.g. by layer) across devices, and run the Adam update only on the shard owned by each device.
  3. Broadcast the updated parameters from their owner devices to all other devices.

Thus, ZeroRedundancyOptimizer incurs an extra broadcast of parameters on top of the baseline (plain DDP) communication.
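For reference, here is a minimal usage sketch of that flow, assuming a single-node multi-GPU job launched with torchrun; the model, shapes, and hyperparameters are placeholders:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer

def main():
    # Launched with torchrun, so RANK/WORLD_SIZE env vars are already set.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # DDP still all-reduces the full gradients during backward (step 1).
    model = DDP(torch.nn.Linear(1024, 1024).cuda(rank), device_ids=[rank])

    # Each rank keeps Adam state only for the parameters it owns (step 2);
    # optimizer.step() then broadcasts updated params from their owners (step 3).
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.Adam,
        lr=1e-3,
    )

    loss = model(torch.randn(32, 1024).cuda(rank)).sum()
    loss.backward()
    optimizer.step()

if __name__ == "__main__":
    main()
```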

But as far as I know, the steps of the original ZeRO-1 in DeepSpeed are:

  1. Compute the gradients of all layers on every device, then ReduceScatter all of the gradients, so that each device ends up with a reduced gradient shard for the parameters it owns.
  2. Update the owned parameter shard using the sharded gradients (and sharded optimizer state).
  3. AllGather the updated parameter shards so that every device holds the full parameters again.

Compared to the baseline, this adds no extra communication: a ring AllReduce is itself implemented as a ReduceScatter followed by an AllGather, so splitting it into those two collectives moves the same total volume.
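To make the pattern concrete, here is a hand-rolled sketch of those three steps for a single flat parameter tensor using plain torch.distributed collectives. This is not DeepSpeed's implementation; `zero1_step` and all of its arguments are hypothetical, the process group is assumed to be initialized already, the world size is assumed to divide the parameter count evenly, and `flat_param` / `flat_grad` are plain flattened tensors outside autograd:

```python
import torch
import torch.distributed as dist

def zero1_step(flat_param, flat_grad, exp_avg, exp_avg_sq, step,
               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One ZeRO-1-style update: ReduceScatter grads, local Adam on the
    owned shard, AllGather the updated parameter shards."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # 1) ReduceScatter: rank i receives only the summed gradient chunk i.
    grad_chunks = list(flat_grad.chunk(world_size))
    grad_shard = torch.empty_like(grad_chunks[rank])
    dist.reduce_scatter(grad_shard, grad_chunks, op=dist.ReduceOp.SUM)
    grad_shard.div_(world_size)  # average, matching DDP's all-reduce semantics

    # 2) Local Adam on the owned shard; exp_avg / exp_avg_sq are shard-sized,
    #    so the optimizer state is partitioned too (the point of ZeRO-1).
    param_chunks = list(flat_param.chunk(world_size))
    p = param_chunks[rank]
    exp_avg.mul_(beta1).add_(grad_shard, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad_shard, grad_shard, value=1 - beta2)
    denom = (exp_avg_sq / (1 - beta2 ** step)).sqrt_().add_(eps)
    p.addcdiv_(exp_avg / (1 - beta1 ** step), denom, value=-lr)

    # 3) AllGather the updated shards so every rank holds the full parameters.
    dist.all_gather(param_chunks, p)
```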

So why does PyTorch implement the former design instead of the latter? Is there a roadmap for migrating to the latter implementation?

cc @mrshenli who is the SME.

@Bean what you described is essentially the DeepSpeed ZeRO-2 algorithm. We are releasing a new API, torch.distributed.fsdp.FullyShardedDataParallel, in PyTorch 1.11, which is similar to ZeRO-3. We are also working on making it easy to configure as a ZeRO-2-like algorithm, i.e. sharding only the optimizer states and gradients; that mode requires an AllGather in the forward pass and a ReduceScatter in the backward pass.
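A minimal sketch of the default full-sharding (ZeRO-3-like) usage, assuming a torchrun launch with an NCCL backend; the model and shapes are placeholders:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Default FSDP shards parameters, gradients, and optimizer state across ranks:
# parameters are all-gathered for forward/backward, and gradients are
# reduce-scattered during backward.
model = FSDP(torch.nn.Linear(1024, 1024).cuda(rank))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(32, 1024).cuda(rank)).sum()
loss.backward()
optimizer.step()
```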

Thanks for your reply, looking forward to the PyTorch 1.11 FSDP API!