How to use torch.distributed.optim.ZeroRedundancyOptimizer with overlap_with_ddp=True?

The optional parameter overlap_with_ddp is not well explained in the documentation.
I don't understand this requirement from the documentation:

(2) registering a DDP communication hook constructed from one of the functions in

So I just ignored it.

When I tried to test it, I got this warning:

WARNING:root:step() should not be included in the training loop when overlap_with_ddp=True

After I removed optimizer.step(), the warning disappeared, but the parameters were no longer being optimized; they stayed fixed.
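A minimal sketch of one way to confirm whether parameters are actually being updated: snapshot them before training and compare afterwards. The helper names `snapshot_params` and `changed_params` are hypothetical, and the demo uses a plain `nn.Linear` with a regular SGD optimizer rather than DDP/ZeRO, assuming the same check can be applied to `ddp_model` on any rank after a few real iterations.

```python
from typing import Dict, List

import torch
import torch.nn as nn


def snapshot_params(module: nn.Module) -> Dict[str, torch.Tensor]:
    """Return detached copies of all parameters, keyed by name (hypothetical helper)."""
    return {name: p.detach().clone() for name, p in module.named_parameters()}


def changed_params(module: nn.Module, before: Dict[str, torch.Tensor]) -> List[str]:
    """Return the names of parameters whose values differ from the snapshot."""
    return [
        name
        for name, p in module.named_parameters()
        if not torch.equal(p.detach(), before[name])
    ]


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    before = snapshot_params(model)
    model(torch.randn(8, 4)).sum().backward()
    print(changed_params(model, before))  # [] : backward alone does not update parameters
    optim.step()
    print(changed_params(model, before))  # ['weight', 'bias']
```

If the list stays empty after several training iterations, no optimizer update is reaching the model.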

I am unsure how to use this parameter properly.

Could an example be added to explain this?

cc @awgu for ZeroRedundancyOptimizer

Apologies for the delay and the lack of documentation. The feature was added experimentally and has not been revisited since.

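I think the reason you are not seeing the parameters being optimized is that you skipped the step labeled (2). You have to register one of the two DDP communication hooks it refers to, namely either hook_with_zero_step() or hook_with_zero_step_interleaved(). These wrap a regular DDP communication hook (such as allreduce_hook) so that ZeRO's optimizer step runs inside the backward pass. That is why step() should not be called in the training loop, but without the hook, nothing calls it at all.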
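To choose between the two, here is a small sketch based on my reading of the docstrings in ddp_zero_hook: the non-interleaved variant starts the optimizer computation after the last gradient bucket has been computed (suggested when communication is slow relative to computation), while the interleaved variant launches each bucket's optimizer computation as soon as that bucket's gradients are ready (suggested when communication is relatively fast). The helpers `pick_zero_hook_factory` and `register_zero_hook` are hypothetical; only the imported functions are real. As far as I know, this overlap path currently expects the NCCL backend, so actually calling `register_zero_hook` needs a GPU setup.

```python
from typing import Any, Callable

from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleaved,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook


def pick_zero_hook_factory(interleaved: bool) -> Callable[..., Any]:
    """Return the ddp_zero_hook wrapper to use (hypothetical helper)."""
    # Non-interleaved: optimizer computation follows the whole backward computation,
    # overlapping only with outstanding communication.
    # Interleaved: per-bucket optimizer computation overlaps with backward computation.
    return hook_with_zero_step_interleaved if interleaved else hook_with_zero_step


def register_zero_hook(ddp_model, zero_optim, interleaved: bool = False) -> None:
    """Wrap DDP's default all-reduce hook so ZeRO's step runs during backward."""
    factory = pick_zero_hook_factory(interleaved)
    # Both wrappers take (hook, ddp, zero); state=None makes allreduce_hook
    # use the default process group.
    hook = factory(allreduce_hook, ddp_model, zero_optim)
    ddp_model.register_comm_hook(state=None, hook=hook)
```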

Let me provide some pseudocode for how you might use this:

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
    allreduce_hook,
)
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes the process group is already initialized and that `rank`,
# `dummy_input`, `dataloader`, and `loss_fn` are defined for your setup.
model: nn.Module = ...  # define a model
ddp_model = DDP(model, device_ids=[rank])
zero_optim = ZeroRedundancyOptimizer(
    ddp_model.parameters(),
    optimizer_class=torch.optim.SGD,
    overlap_with_ddp=True,
    lr=1e-3,
    momentum=0.9,
)
# Wrap the all-reduce hook so that ZeRO's optimizer step runs during backward
hook = hook_with_zero_step(allreduce_hook, ddp_model, zero_optim)
ddp_model.register_comm_hook(state=None, hook=hook)
# The first two iterations are warmup and do not update the parameters
for _ in range(2):
    output = ddp_model(dummy_input)
    loss = output.sum()  # compute an arbitrary loss
    loss.backward()  # compute arbitrary gradients but no optimizer update
# Now we may train with DDP overlapped with ZeRO (no zero_optim.step() call)
for inputs, labels in dataloader:
    zero_optim.zero_grad()  # clear leftover gradients, including the warmup ones
    output = ddp_model(inputs)
    loss = loss_fn(output, labels)  # compute the actual loss
    loss.backward()  # compute actual gradients followed by an optimizer update
```
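The number of warmup iterations depends on DDP's `static_graph` setting: the ZeroRedundancyOptimizer docs say no parameter updates happen in the first two iterations with `static_graph=False` and the first three with `static_graph=True`, because DDP's gradient bucketing is not finalized until then. A minimal sketch of a warmup helper under that assumption; the name `run_zero_overlap_warmup` is hypothetical, and the test uses a plain `nn.Linear` as a stand-in for `ddp_model`.

```python
import torch
import torch.nn as nn


def run_zero_overlap_warmup(
    ddp_model: nn.Module, dummy_input: torch.Tensor, static_graph: bool = False
) -> int:
    """Run the dummy iterations during which overlapped ZeRO performs no updates.

    `static_graph` should match the value passed to DDP(..., static_graph=...).
    Returns the number of warmup iterations that were run.
    """
    num_warmup = 3 if static_graph else 2
    for _ in range(num_warmup):
        output = ddp_model(dummy_input)
        # Arbitrary loss; the resulting gradients must be cleared with
        # zero_grad() before real training starts.
        output.sum().backward()
    return num_warmup
```

In the loop above, this would replace the two-iteration warmup loop; the `zero_optim.zero_grad()` call at the start of each training iteration already clears the gradients it leaves behind.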

I hope this helps, and let me know if you have followup questions.