ZeroRedundancyOptimizer consolidate_state_dict warning

I have been trying to reduce my memory footprint while training with DDP and came across ZeroRedundancyOptimizer. I implemented it much like the documentation example, simply wrapping torch.optim.AdamW in it. Below is a rough version of the training loop, which is called from torch.multiprocessing.spawn.

from torch.cuda.amp import GradScaler, autocast
from torch.distributed.optim import ZeroRedundancyOptimizer

device = f'cuda:{gpu}'
torch.cuda.set_device(gpu)  # gpu ranges from 0 to N-1 for N GPUs; rank == gpu for a single-node spawn

net = net.to(device)
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpu],
                                                gradient_as_bucket_view=True,
                                                find_unused_parameters=True)
scaler = GradScaler()

# Each rank only keeps the optimizer state for its own shard of the parameters
optimizer = ZeroRedundancyOptimizer(net.parameters(),
                                    optim=torch.optim.AdamW, lr=learning_rate)

for sample, ground_truth in train_loader:
    optimizer.zero_grad(set_to_none=True)

    # Mixed-precision forward pass
    with autocast():
        pred = net(sample.to(device))
        loss = loss_fn(pred, ground_truth.to(device))

    # Scaled backward pass and optimizer step
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

However, I am getting the warning: Optimizer state has not been consolidated. Returning the local state. Please call "consolidate_state_dict()" beforehand if you meant to save the global state.

Am I doing something wrong with this implementation?

I am running torch==1.8.1 and cudatoolkit==10.2.

cc @mrshenli @rvarm1 in case you have any insights into what might be going wrong here.

I'm guessing that some part of the autocast/grad-scaling path is attempting to call optimizer.state_dict(), but ZeRO requires the state dict to be consolidated first. This might be a limitation in how autocast and ZeRO work together.
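For reference, the warning itself is raised by ZeroRedundancyOptimizer.state_dict() whenever consolidate_state_dict() has not been called first, so anything that asks the optimizer for its state dict will trigger it. Here is a rough single-process sketch that reproduces it (hypothetical toy model, using the optim= keyword from your 1.8.1 snippet):

import os
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer

# Single-process "distributed" setup just so the sharded optimizer can be built
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
opt = ZeroRedundancyOptimizer(model.parameters(), optim=torch.optim.AdamW, lr=1e-3)

model(torch.randn(2, 4)).sum().backward()
opt.step()

# Warns "Optimizer state has not been consolidated" because the sharded state
# was never gathered onto a single rank before asking for the state dict
local_only = opt.state_dict()

dist.destroy_process_group()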

@kleingeo Could you please file an issue with a full repro on the PyTorch GitHub repo? Thanks!

I found the cause. The warning was coming from the fact that I was saving the optimizer state dict after each epoch without calling optimizer.consolidate_state_dict() first. After adding that call before saving, the warning went away.
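For anyone who hits the same warning, here is a minimal sketch of the fixed checkpointing step (the checkpoint path and the exact contents are placeholders, not my actual code): consolidate the sharded optimizer state onto one rank, then save only from that rank.

# Minimal sketch of the fix; placeholder names, assumes torch.distributed is
# already initialized and `rank` is the process rank passed by mp.spawn
optimizer.consolidate_state_dict()  # gathers all shards onto rank 0 by default

if rank == 0:
    torch.save({
        'model': net.state_dict(),
        'optimizer': optimizer.state_dict(),  # no longer warns: state is consolidated
        'scaler': scaler.state_dict(),
    }, 'checkpoint.pt')  # placeholder path

# Keep all ranks in sync so nobody races ahead while rank 0 writes the file
torch.distributed.barrier()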