Apologies for the delay and lack of documentation. The feature was added experimentally and not revisited.
I think the reason you are not seeing the parameters being optimized is that you skipped the step labeled (2): you have to register one of the two DDP communication hooks from ddp_zero_hook.py, namely either hook_with_zero_step() or hook_with_zero_step_interleaved().
Let me provide some pseudocode for how you might use this:
import torch
import torch.nn as nn

from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleaved,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
    allreduce_hook,
)
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

model: nn.Module = ...  # define a model
ddp_model = DDP(model, device_ids=[rank])  # `rank` is this process's rank/device
zero_optim = ZeroRedundancyOptimizer(
    ddp_model.parameters(),
    optimizer_class=torch.optim.SGD,
    overlap_with_ddp=True,  # required so the optimizer step can run inside the hook
    lr=1e-3,
    momentum=0.9,
)

# (2) Register one of the two overlapping DDP communication hooks
hook = hook_with_zero_step(allreduce_hook, ddp_model, zero_optim)
ddp_model.register_comm_hook(state=None, hook=hook)
# The first two iterations are warmup and do nothing
for _ in range(2):
    output = ddp_model(dummy_input)  # `dummy_input`: any correctly shaped input
    loss = output.sum()  # compute arbitrary loss
    loss.backward()  # compute arbitrary gradients but no optimizer update

# Now we may train, overlapping DDP with ZeRO
for (input, label) in dataloader:
    zero_optim.zero_grad()  # clear the arbitrary gradients from above
    output = ddp_model(input)
    loss = loss_fn(output, label)  # compute actual loss
    loss.backward()  # compute actual gradients; the optimizer step runs inside the hook
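If you would rather use the interleaved variant, the only change (as far as I recall) is which constructor builds the hook; everything else above stays the same:

# Alternative: build the hook with the interleaved variant instead
hook = hook_with_zero_step_interleaved(allreduce_hook, ddp_model, zero_optim)
ddp_model.register_comm_hook(state=None, hook=hook)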
I hope this helps, and let me know if you have follow-up questions.
Unfortunately, we are no longer actively developing ZeroRedundancyOptimizer, so support for multiple parameter groups is unlikely to be added.
You could consider switching your DDP setup to PyTorch’s FullyShardedDataParallel (FSDP) with ShardingStrategy.SHARD_GRAD_OP, which supports multiple parameter groups when use_orig_params=True.
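For reference, here is a minimal sketch of what that FSDP setup could look like; the bias-based parameter-group split, the weight-decay values, and the `rank` variable are illustrative assumptions, not taken from your setup:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

model: nn.Module = ...  # define a model
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # shard gradients and optimizer state
    use_orig_params=True,  # expose the original parameters so param groups work
    device_id=rank,  # this process's device, as in the DDP example above
)

# With use_orig_params=True, named_parameters() yields the original parameters,
# so multiple parameter groups can be built as usual (illustrative split below).
decay, no_decay = [], []
for name, param in fsdp_model.named_parameters():
    (no_decay if "bias" in name else decay).append(param)
optim = torch.optim.SGD(
    [
        {"params": decay, "weight_decay": 1e-4},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
    momentum=0.9,
)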