ZeRO optimizer: consolidate_state_dict() hangs

I am trying to use the ZeRO optimizer (ZeroRedundancyOptimizer), but it hangs on optimizer.consolidate_state_dict().

Why does this code hang? Did I do something silly? TIA!

r"""Run ``torchrun --nproc_per_node=2 test.py``
"""
import torch
from torch import distributed as dist
from torchvision.models import resnet18
from torch import nn
from torch.optim import Adam
from torch.distributed.optim.zero_redundancy_optimizer import (
    ZeroRedundancyOptimizer,
)


def main() -> None:
    dist.init_process_group("nccl")
    device = torch.device(dist.get_rank())
    model = nn.parallel.DistributedDataParallel(resnet18().to(device))
    loss_func = nn.CrossEntropyLoss()
    optimizer = ZeroRedundancyOptimizer(model.parameters(), Adam)
    samples = torch.randn(1, 3, 64, 64).to(device)
    targets = torch.randint(1000, (1,)).to(device)
    outputs = model(samples)
    loss = loss_func(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(True)
    optimizer.consolidate_state_dict(dist.get_rank())  # will hang
    torch.save({"optimizer": optimizer.state_dict(), }, "test.pth")

if __name__ == "__main__":
    main()

cc @rvarm1 @agu about the ZeRO optimizer

The argument to consolidate_state_dict() is the target rank (which is 0 by default). In your code, each process passes its own rank (dist.get_rank()), so the ranks disagree on the target during the internal collectives, which is why the call hangs. If you want to save the unsharded state dict on all ranks, then you need to loop over all ranks as the target:

for rank in range(dist.get_world_size()):
    optimizer.consolidate_state_dict(rank)
torch.save({"optimizer": optimizer.state_dict(), }, "test.pth")

However, that will redundantly save the same state dict from all ranks. If that is not desired, you can do something like:

optimizer.consolidate_state_dict()  # default to rank 0
if dist.get_rank() == 0:
    torch.save({"optimizer": optimizer.state_dict(), }, "test.pth")

It still hangs:

r"""Run ``torchrun --nproc_per_node=2 test.py``
"""
import torch
from torch import distributed as dist
from torchvision.models import resnet18
from torch import nn
from torch.optim import Adam
from torch.distributed.optim.zero_redundancy_optimizer import (
    ZeroRedundancyOptimizer,
)


def main() -> None:
    dist.init_process_group("nccl")
    device = torch.device(dist.get_rank())
    model = nn.parallel.DistributedDataParallel(resnet18().to(device))
    loss_func = nn.CrossEntropyLoss()
    optimizer = ZeroRedundancyOptimizer(model.parameters(), Adam)
    samples = torch.randn(1, 3, 64, 64).to(device)
    targets = torch.randint(1000, (1,)).to(device)
    outputs = model(samples)
    loss = loss_func(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(True)
    for rank in range(dist.get_world_size()):
        optimizer.consolidate_state_dict(rank)
    if dist.get_rank() == 0:
        torch.save({"optimizer": optimizer.state_dict()}, "test.pth")

if __name__ == "__main__":
    main()

My bad. Every call to optimizer.consolidate_state_dict(rank) clears any previously consolidated state dicts, so the loop for rank in range(dist.get_world_size()) was actually clearing rank 0’s consolidated state dict on each subsequent iteration.

This should work (I checked this one locally):

optimizer.consolidate_state_dict(0)
if dist.get_rank() == 0:
    torch.save(...)

Or, if you want to save on all ranks, then you should call torch.save() in the loop (I did not check this one):

for rank in range(dist.get_world_size()):
    optimizer.consolidate_state_dict(rank)
    if dist.get_rank() == rank:
        torch.save(...)
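
For reference, a minimal end-to-end sketch of the repro with the rank-0 fix applied (the variant checked above). The dist.destroy_process_group() call at the end is just cleanup I added, not part of the fix:

r"""Run ``torchrun --nproc_per_node=2 test.py``"""
import torch
from torch import distributed as dist
from torch import nn
from torch.distributed.optim.zero_redundancy_optimizer import (
    ZeroRedundancyOptimizer,
)
from torch.optim import Adam
from torchvision.models import resnet18


def main() -> None:
    dist.init_process_group("nccl")
    device = torch.device(dist.get_rank())
    model = nn.parallel.DistributedDataParallel(resnet18().to(device))
    loss_func = nn.CrossEntropyLoss()
    optimizer = ZeroRedundancyOptimizer(model.parameters(), Adam)

    # One dummy training step so the optimizer has state to consolidate.
    samples = torch.randn(1, 3, 64, 64).to(device)
    targets = torch.randint(1000, (1,)).to(device)
    outputs = model(samples)
    loss = nn.CrossEntropyLoss()(outputs, targets) if False else loss_func(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(True)

    # All ranks must call consolidate_state_dict() with the SAME target rank;
    # only that rank ends up holding the full (unsharded) optimizer state.
    optimizer.consolidate_state_dict(0)
    if dist.get_rank() == 0:
        torch.save({"optimizer": optimizer.state_dict()}, "test.pth")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

If every rank needs its own copy on disk, substitute the per-rank loop from the previous snippet, calling torch.save() inside the loop on the matching rank.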