Resuming training with mixed precision leads to "No inf checks were recorded for this optimizer"

I'm using mixed (half) precision with torch.cuda.amp.GradScaler() to train my model.
I'm trying to build a training pipeline in which any training run can be stopped and resumed.
For that, at each epoch I save:

  • model state_dict
  • optimizer state_dict
  • scaler state_dict (scaler = torch.cuda.amp.GradScaler())
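A minimal sketch of the per-epoch save step listed above. It assumes a toy nn.Linear model and an in-memory buffer standing in for a checkpoint file. The helper name make_checkpoint and the keys "epoch" and "model_state_dict" are illustrative. The keys "optimizer_state_dict" and "scaler_state_dict" match the ones used in the loading code. The scaler is only enabled when CUDA is available, because a disabled GradScaler does nothing and returns an empty state dict.

```python
import io

import torch
import torch.nn as nn


def make_checkpoint(epoch, model, optimizer, scaler):
    # Hypothetical helper: bundle everything needed to resume into one dict.
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scaler_state_dict": scaler.state_dict(),  # {} if the scaler is disabled
    }


use_amp = torch.cuda.is_available()  # GradScaler only does real work on CUDA
device = "cuda" if use_amp else "cpu"

model = nn.Linear(4, 2).to(device)  # toy model, assumed for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

buffer = io.BytesIO()  # stands in for a path such as "checkpoint.pt"
torch.save(make_checkpoint(0, model, optimizer, scaler), buffer)
```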

To resume training, I load the state dicts above into the model, optimizer and scaler. The following code shows the part that loads the optimizer and scaler state dicts:

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scaler.load_state_dict(checkpoint['scaler_state_dict'])
# I tried moving the optimizer state to CUDA, but I get the same error with or without this loop
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)
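One way to structure the resume step is sketched below, using a toy model and a hypothetical build() factory. First create the model and move it to its device. Then build the optimizer from that same model's parameters. Only after that, load the three state dicts. Optimizer.load_state_dict already casts floating-point state tensors (such as Adam's exp_avg) to the device and dtype of the matching parameters, so a manual loop that moves the optimizer state is normally not needed. So that the optimizer actually has state to restore, the block first writes a checkpoint after one training step.

```python
import io

import torch
import torch.nn as nn

use_amp = torch.cuda.is_available()
device = torch.device("cuda" if use_amp else "cpu")


def build(device):
    # Hypothetical factory: the model is moved to its device *before* the
    # optimizer is created, so the optimizer holds the same Parameter
    # objects that the forward pass will use.
    model = nn.Linear(4, 2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    return model, optimizer, scaler


# Produce a checkpoint after one step so Adam has exp_avg / exp_avg_sq state.
model, optimizer, scaler = build(device)
with torch.cuda.amp.autocast(enabled=use_amp):
    loss = model(torch.randn(8, 4, device=device)).mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

buffer = io.BytesIO()
torch.save({"model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scaler_state_dict": scaler.state_dict()}, buffer)

# Resume: fresh objects built in the right order, then restore their states.
buffer.seek(0)
checkpoint = torch.load(buffer, map_location=device)
model, optimizer, scaler = build(device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scaler.load_state_dict(checkpoint["scaler_state_dict"])  # no-op if disabled
```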

When I resume training, everything seems to load correctly, but at the first iteration, when I step the scaler during backpropagation, I get the following error:

  File ".../base_engine.py", line 97, in backprop_loss
    self.scaler.step(self.optimizer)
  File ".../site-packages/torch/cuda/amp/grad_scaler.py", line 318, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
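This assertion is raised inside GradScaler.step when the preceding unscale pass found no gradients at all. GradScaler only records an inf check for those of the optimizer's parameters whose .grad is not None. A quick diagnostic is to count how many of the optimizer's parameters received a gradient right before calling scaler.step(optimizer); the sketch below does this with a hypothetical helper, count_params_with_grad. To show one way of ending up with zero, it assumes a deliberately broken setup where the optimizer is built from one model instance while the forward and backward pass use another. Parameters that are all frozen (requires_grad=False) would give the same result.

```python
import torch
import torch.nn as nn


def count_params_with_grad(optimizer):
    """Hypothetical helper: (params with a .grad, total params) in the optimizer."""
    with_grad, total = 0, 0
    for group in optimizer.param_groups:
        for p in group["params"]:
            total += 1
            if p.grad is not None:
                with_grad += 1
    return with_grad, total


device = "cuda" if torch.cuda.is_available() else "cpu"

# Broken on purpose: the optimizer tracks stale_model's parameters,
# but the loss is computed with a different, freshly created model.
stale_model = nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(stale_model.parameters(), lr=0.1)
model = nn.Linear(4, 2).to(device)

loss = model(torch.randn(8, 4, device=device)).mean()
loss.backward()

with_grad, total = count_params_with_grad(optimizer)
if with_grad == 0:
    # In this state, scaler.step(optimizer) would raise
    # "No inf checks were recorded for this optimizer."
    print(f"none of the optimizer's {total} parameters have a .grad")
```

If the count is zero after resuming, a reasonable next check is whether the optimizer was constructed from model.parameters() of the exact model instance used in the training loop.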

However, when I disable mixed precision after resuming (with torch.cuda.amp.autocast(enabled=False)), training seems to work correctly.

I also tried different optimizers: PyTorch's SGD and Adam, as well as custom implementations such as RAdam and SGDP.

Any idea how I could solve this?

Are you seeing the same issue without moving the states to another device?
Could you also explain why this code is needed?

Yes, I get the same issue without this part.
In fact, I didn't have the state-moving loop at first; I only added it to see whether it would fix the error, but it doesn't.

I cannot reproduce this issue using:

import torch
import torchvision.models as models

model = models.resnet18().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for epoch in range(10):
    optimizer.zero_grad()
    
    with torch.cuda.amp.autocast():
        out = model(torch.randn(1, 3, 224, 224).cuda())
        loss = out.mean()
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    print('epoch {}, loss {}'.format(epoch, loss.item()))

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scaler': scaler.state_dict()
}
torch.save(checkpoint, 'tmp.pt')

# load
cp = torch.load('tmp.pt')
model = models.resnet18().cuda()
model.load_state_dict(cp['model'])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer.load_state_dict(cp['optimizer'])
scaler = torch.cuda.amp.GradScaler()
scaler.load_state_dict(cp['scaler'])

for epoch in range(10):
    optimizer.zero_grad()
    
    with torch.cuda.amp.autocast():
        out = model(torch.randn(1, 3, 224, 224).cuda())
        loss = out.mean()
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    print('epoch {}, loss {}'.format(epoch, loss.item()))

Could you post an executable code snippet that reproduces this error?