How to use torch.distributed.optim.ZeroRedundancyOptimizer with overlap_with_ddp=True?

The optional parameter overlap_with_ddp is not well illustrated in the documentation.
I don't understand this requirement from the docs:

(2) registering a DDP communication hook constructed from one of the functions in ddp_zero_hook.py

So I just ignored it.

When I tried to test it, I got this warning:

WARNING:root:step() should not be included in the training loop when overlap_with_ddp=True

Then after I removed optimizer.step(), the warning disappeared, but the parameters were no longer being optimized; they stayed fixed.

I am confused about how to make use of this parameter properly.

Is it possible to add an example to explain this?

cc @awgu for the ZeRO optimizer

Apologies for the delay and lack of documentation. The feature was added experimentally and not revisited.

I think the reason you are not seeing the parameters being optimized is that you skipped the step labeled (2). You have to register one of the two DDP communication hooks from ddp_zero_hook.py: either hook_with_zero_step() or hook_with_zero_step_interleaved().

Let me provide some pseudocode for how you might use this:

import torch
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleaved,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
    allreduce_hook,
)
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

model: nn.Module = ...  # define a model
ddp_model = DDP(model, device_ids=[rank])  # `rank` is this process's GPU index
zero_optim = ZeroRedundancyOptimizer(
    ddp_model.parameters(),
    optimizer_class=torch.optim.SGD,
    overlap_with_ddp=True,
    lr=1e-3,
    momentum=0.9,
)
hook = hook_with_zero_step(allreduce_hook, ddp_model, zero_optim)
ddp_model.register_comm_hook(state=None, hook=hook)
# The first two iterations are warmup: forward/backward only, no parameter updates yet
for i in range(2):
    output = ddp_model(dummy_input)
    loss = output.sum()  # compute arbitrary loss
    loss.backward()  # compute arbitrary gradients but no optimizer update
# Now we may train overlapping DDP with ZeRO
for (input, label) in dataloader:
    zero_optim.zero_grad()  # need to clear the arbitrary gradients from above
    output = ddp_model(input)
    loss = loss_fn(output, label)  # compute actual loss
    loss.backward()  # compute actual gradients followed by an optimizer update
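
If you would rather interleave the optimizer computation with the backward computation (launching the local step for each bucket as soon as its gradients are ready), swap in the other hook; the registration is otherwise the same. A minimal sketch, reusing ddp_model and zero_optim from above:

# Alternative: interleave the ZeRO local step with the backward pass bucket by bucket
hook = hook_with_zero_step_interleaved(allreduce_hook, ddp_model, zero_optim)
ddp_model.register_comm_hook(state=None, hook=hook)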

I hope this helps, and let me know if you have followup questions.


Hey there @awgu

We followed the example above, but it failed because we used two parameter groups (one for decay and one for no-decay).

The exact error was:

RuntimeError: Specifying `params_per_rank` only supports a single parameter group

How do we modify the pseudocode to support multiple parameter groups?

The ZeroRedundancyOptimizer worked correctly with multiple parameter groups when we disabled overlap_with_ddp, but we need ludicrous speed :wink:
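
For reference, this is roughly what worked for us without overlap (a minimal sketch; decay_params and no_decay_params are our own split of the model's parameters, and the second group is added via the standard add_param_group API):

# First (decay) group goes to the constructor; overlap_with_ddp stays at its default of False
zero_optim = ZeroRedundancyOptimizer(
    decay_params,
    optimizer_class=torch.optim.AdamW,
    lr=1e-3,
    weight_decay=0.01,
)
# Second (no-decay) group is added afterwards
zero_optim.add_param_group({"params": no_decay_params, "weight_decay": 0.0})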

Thanks!

Unfortunately, we are not actively developing ZeroRedundancyOptimizer anymore, so adding support for multiple parameter groups with overlap_with_ddp=True is unlikely.

You can consider switching your DDP setup to PyTorch’s FullyShardedDataParallel with ShardingStrategy.SHARD_GRAD_OP, which supports multiple parameter groups with use_orig_params=True.
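
A minimal sketch of that setup (the decay/no-decay split by parameter name below is just an illustrative assumption, and `rank` is this process's GPU index as in the earlier pseudocode):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

model: nn.Module = ...  # define a model
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # shard gradients and optimizer state (ZeRO-2 style)
    use_orig_params=True,  # keep the original parameters visible so param groups work
    device_id=rank,
)

# Build decay / no-decay groups from the original parameter names
decay, no_decay = [], []
for name, param in fsdp_model.named_parameters():
    (no_decay if name.endswith("bias") else decay).append(param)

optim = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)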
