[DDP + amp + gradient accumulation] Calling optimizer.step() leads to NaN loss

I want to add a gradient accumulation feature to my DDP + amp training program so I can keep a constant effective batch size when training large models. The model is wrapped in torch.nn.parallel.DistributedDataParallel.

Referring to

  1. the 2nd case in albanD’s reply, and
  2. the PyTorch documentation on amp working with gradient accumulation,

I implemented my code like this:

scaler = torch.cuda.amp.GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
for i, input in enumerate(train_loader):
    with torch.cuda.amp.autocast(True):
        loss = model(input) / iters_to_accumulate   # Normalize our loss (if averaged). See https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3#file-gradient_accumulation-py-L5.
    scaler.scale(loss).backward()                   # Accumulate scaled gradients
    if (i + 1) % iters_to_accumulate == 0:          # Wait for several backward steps
        scaler.step(optimizer)
        optimizer.step()                            # Now we can do an optimizer step
        scaler.update()
        optimizer.zero_grad()                       # Reset gradient tensors only if we have done a step. See https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3?permalink_comment_id=2921188#gistcomment-2921188 and https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation.

The loss turns out to be NaN from the second iteration onward (when i == 1).
But when I comment out the optimizer.step() line, everything works like a charm.
Later, I found that in both
BramVanroy’s thread and
the PyTorch documentation on amp working with gradient accumulation,
optimizer.step() is intentionally not called.

My question is: should optimizer.step() be called in my case?
If yes, how do I solve the NaN loss problem?
If not, why not, and would omitting it affect the performance of the model?

No, it shouldn’t be, as shown in the amp tutorial. The GradScaler.step function calls optimizer.step() internally after unscaling the gradients and making sure no Infs/NaNs were found.
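
For reference, a minimal sketch following the linked amp gradient accumulation example, adapted to the snippet in the question (args, iters_to_accumulate, train_loader, and the DDP-wrapped model are assumed to be defined as above); the only stepping call is scaler.step(optimizer):

scaler = torch.cuda.amp.GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
for i, input in enumerate(train_loader):
    with torch.cuda.amp.autocast(True):
        loss = model(input) / iters_to_accumulate   # normalize the loss over the accumulated iterations
    scaler.scale(loss).backward()                   # accumulate scaled gradients
    if (i + 1) % iters_to_accumulate == 0:
        scaler.step(optimizer)    # unscales gradients, then calls optimizer.step() only if no inf/NaN is found
        scaler.update()           # adjust the scale factor for the next iterations
        optimizer.zero_grad()     # no manual optimizer.step() anywhere

A likely explanation for the NaN loss: a manual optimizer.step() bypasses the scaler’s inf/NaN check and, if the gradients have not been unscaled yet, applies an update using gradients multiplied by the (large) scale factor, which corrupts the parameters.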

Thank you.

For those who may encounter the same issue, the documentation for torch.cuda.amp.GradScaler.step (Automatic Mixed Precision package - torch.amp) says

  1. If no inf/NaN gradients are found, invokes optimizer.step() using the unscaled gradients. Otherwise, optimizer.step() is skipped to avoid corrupting the params.

So optimizer.step() should not be called manually.
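
If you want to check how often steps are actually skipped, one common heuristic (an assumption on my part, not an official API) is to compare the scale before and after scaler.update(): the scale is backed off only when inf/NaN gradients were found. A sketch, reusing scaler, optimizer, and the loop variables from the snippets above:

    if (i + 1) % iters_to_accumulate == 0:
        scale_before = scaler.get_scale()
        scaler.step(optimizer)                   # skipped internally if inf/NaN gradients are found
        scaler.update()
        if scaler.get_scale() < scale_before:    # scale is only reduced after a skipped step
            print(f"iteration {i}: optimizer.step() was skipped due to inf/NaN gradients")
        optimizer.zero_grad()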