The optional parameter overlap_with_ddp is not well explained in the documentation.
I don't understand this requirement from the documentation:
(2) registering a DDP communication hook constructed from one of the functions in ddp_zero_hook.py
So I just ignored it.
When I tried to test it, I got this warning:
step() should not be included in the training loop when overlap_with_ddp=True
After I removed optimizer.step(), the warning disappeared, but the parameters were no longer updated; they stayed fixed.
I am confused about how to make use of this parameter properly.
Is it possible to add an example to explain this?
cc @awgu for using Zero optimizer
Apologies for the delay and lack of documentation. The feature was added experimentally and not revisited.
I think the reason you are not seeing the parameters being optimized is that you skipped the step labeled (2). You have to register one of the two DDP communication hooks in ddp_zero_hook.py: namely, either hook_with_zero_step() or hook_with_zero_step_interleaved().
Let me provide some pseudocode for how you might use this:
import torch
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
    allreduce_hook,
)
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

model: nn.Module = ...  # define a model
ddp_model = DDP(model, device_ids=[rank])
zero_optim = ZeroRedundancyOptimizer(
    ddp_model.parameters(), optimizer_class=torch.optim.SGD,
    overlap_with_ddp=True, lr=1e-3, momentum=0.9,
)
# Wrap the default allreduce hook so it also runs the ZeRO step, then register it with DDP
hook = hook_with_zero_step(allreduce_hook, ddp_model, zero_optim)
ddp_model.register_comm_hook(None, hook)

# The first two iterations are warmup and do not update the parameters
for i in range(2):
    output = ddp_model(dummy_input)
    loss = output.sum()  # compute arbitrary loss
    loss.backward()  # compute arbitrary gradients but no optimizer update

# Now we may train overlapping DDP with ZeRO
# Note: do not call zero_optim.step() here; the registered hook runs the update
for (inputs, label) in dataloader:
    zero_optim.zero_grad()  # need to clear the arbitrary gradients from above
    output = ddp_model(inputs)
    loss = loss_fn(output, label)  # compute actual loss
    loss.backward()  # compute actual gradients followed by an optimizer update
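
If you want to confirm that the parameters are really being updated, one simple check is to snapshot them before an iteration and compare afterwards. Below is a minimal sketch of that check. It uses a plain single-process nn.Linear with torch.optim.SGD as a stand-in rather than the DDP/ZeRO setup above, and the helper names snapshot_params and params_changed are hypothetical, not part of PyTorch.

```python
import torch
import torch.nn as nn


def snapshot_params(model: nn.Module) -> list:
    """Return detached copies of all parameters (hypothetical helper)."""
    return [p.detach().clone() for p in model.parameters()]


def params_changed(model: nn.Module, before: list) -> bool:
    """Return True if any parameter differs from its snapshot (hypothetical helper)."""
    return any(
        not torch.equal(p.detach(), b)
        for p, b in zip(model.parameters(), before)
    )


torch.manual_seed(0)
model = nn.Linear(4, 2)  # stand-in model for illustration only
optim = torch.optim.SGD(model.parameters(), lr=1e-1)

before = snapshot_params(model)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# In this plain setup, backward() alone only fills .grad and leaves the weights untouched
changed_without_step = params_changed(model, before)

optim.step()
# After an explicit step, the weights should differ from the snapshot
changed_with_step = params_changed(model, before)

print(changed_without_step, changed_with_step)
```

In the overlapped setup, you would take the snapshot at the start of an iteration in the second loop and call params_changed(ddp_model, before) right after loss.backward(). If the hook is registered, that should report a change even though step() is never called explicitly.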
I hope this helps, and let me know if you have followup questions.