Using GradScaler() with more than one model in a training loop

Hi there. To save GPU memory, I use FP16 in my work just like nnUNet does; it defines a GradScaler() to scale the loss and update the gradients. I modified that code for my own use:

The loss and the optimizer are now updated twice.

  • How does this compare with the original update order, i.e. calling optimizer1.step() and then optimizer2.step()? What's the difference, and will it work as intended?
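To make the difference concrete, here is a deliberately simplified, hypothetical model of what scaler.unscale_(opt) followed by scaler.step(opt) does for one optimizer. Unlike a plain optimizer.step(), the scaled step divides that optimizer's gradients by the loss scale and skips its update if any of them are inf/NaN, so each optimizer decides independently. The function name simplified_scaled_step and the fixed scale of 1024 are assumptions for illustration only; the real GradScaler also tracks per-device state and lowers the scale in update() if any optimizer saw inf/NaN.

```python
import torch
import torch.nn as nn


def simplified_scaled_step(optimizer, scale):
    """Hypothetical, simplified stand-in for scaler.unscale_(opt) + scaler.step(opt).

    NOT PyTorch's implementation: the real GradScaler also tracks per-device
    state, handles sparse grads, and changes the scale only in update().
    Returns True if the optimizer stepped, False if its step was skipped.
    """
    checked_any = False
    found_inf = False
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue  # params without grads are not checked
            checked_any = True
            p.grad.div_(scale)  # unscale in place
            if not torch.isfinite(p.grad).all():
                found_inf = True
    if not checked_any:
        # Mirrors the real scaler's error when no param of this optimizer had a grad.
        raise AssertionError("No inf checks were recorded for this optimizer.")
    if found_inf:
        return False  # skip only THIS optimizer's step
    optimizer.step()
    return True


torch.manual_seed(0)
model1, model2 = nn.Linear(1, 1), nn.Linear(1, 1)
optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.1)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1)
params1_before = [p.detach().clone() for p in model1.parameters()]
params2_before = [p.detach().clone() for p in model2.parameters()]

scale = 1024.0
loss = model2(model1(torch.randn(4, 1))).mean()
(loss * scale).backward()  # plays the role of scaler.scale(loss).backward()
model2.bias.grad.fill_(float("inf"))  # simulate an overflow in model2 only

stepped1 = simplified_scaled_step(optimizer1, scale)  # True: model1 is updated
stepped2 = simplified_scaled_step(optimizer2, scale)  # False: model2 is left unchanged
```

With plain optimizer1.step() and optimizer2.step() there is no such check, so an overflowed FP16 gradient would be applied directly to the weights.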

Here is the example provided in the PyTorch docs [AUTOMATIC MIXED PRECISION PACKAGE - TORCH.AMP]:

scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()

To update optimizer1 and optimizer2, I changed the code like this:

scaler.scale(loss).backward()
scaler.unscale_(optimizer1)
scaler.unscale_(optimizer2)  # new
torch.nn.utils.clip_grad_norm_(model1.parameters(), max_norm)
torch.nn.utils.clip_grad_norm_(model2.parameters(), max_norm)  # new
scaler.step(optimizer1)
scaler.step(optimizer2)  # new
scaler.update()
  • But this causes the following error:
AssertionError: No inf checks were recorded for this optimizer.

The loss setup I use is described in: For loss update twice in Gradscaler.scale(loss).backward()


How can I use FP16 with GradScaler so that two losses and two optimizers are updated?
Thank you!

Or should I just define two different scalers? If so, how should I update my loss?
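For the two-loss case, the PyTorch AMP examples page ("Working with Multiple Models, Losses, and Optimizers") uses a single GradScaler shared by all losses and optimizers: each loss is scaled and backpropagated, each optimizer is stepped through the scaler, and update() is called once per iteration. The sketch below follows that pattern; the models, the two loss formulas, the random data, and the CPU fallback (scaler and autocast disabled when CUDA is unavailable) are hypothetical placeholders, not taken from the original post.

```python
import torch
import torch.nn as nn

# Falls back to plain FP32 on machines without CUDA (scaler/autocast disabled).
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

torch.manual_seed(0)
model1 = nn.Linear(4, 1).to(device)
model2 = nn.Linear(4, 1).to(device)
optimizer1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# One scaler shared by both losses and both optimizers.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

w1_before = model1.weight.detach().clone()
w2_before = model2.weight.detach().clone()

for _ in range(10):
    x = torch.randn(8, 4, device=device)
    target = torch.randn(8, 1, device=device)
    optimizer1.zero_grad()
    optimizer2.zero_grad()
    with torch.autocast(device_type=device, enabled=use_amp):
        out1 = model1(x)
        out2 = model2(x)
        # Hypothetical losses: both depend on both models, so their graphs overlap.
        loss1 = loss_fn(2 * out1 + 3 * out2, target)
        loss2 = loss_fn(3 * out1 - 5 * out2, target)

    # retain_graph=True is needed only because the two losses share graph parts.
    scaler.scale(loss1).backward(retain_graph=True)
    scaler.scale(loss2).backward()

    # Each optimizer is checked for inf/NaN grads independently.
    scaler.step(optimizer1)
    scaler.step(optimizer2)
    # update() is called once per iteration, after all step() calls.
    scaler.update()
```

Using two separate scalers is not required for this; one scaler can serve any number of losses and optimizers as long as update() is called once after all the step() calls.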

Your general code works correctly as seen here:

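import torch
import torch.nn as nn

model1 = nn.Linear(1, 1).cuda()
model2 = nn.Linear(1, 1).cuda()
optimizer1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3)

# A single GradScaler is shared by both optimizers
scaler = torch.cuda.amp.GradScaler()
max_norm = 1.

for _ in range(10):
    x = torch.randn(1, 1).cuda()
    # Clear the gradients from the previous iteration
    optimizer1.zero_grad()
    optimizer2.zero_grad()
    with torch.cuda.amp.autocast():
        out = model1(x)
        out = model2(out)
        loss = out.mean()

    scaler.scale(loss).backward()
    # Unscale each optimizer's grads before clipping so max_norm applies to the true gradients
    scaler.unscale_(optimizer1)
    scaler.unscale_(optimizer2)  # new
    torch.nn.utils.clip_grad_norm_(model1.parameters(), max_norm)
    torch.nn.utils.clip_grad_norm_(model2.parameters(), max_norm)  # new
    # Each step() skips its optimizer's update if that optimizer's grads contain inf/NaN
    scaler.step(optimizer1)
    scaler.step(optimizer2)  # new
    # Call update() once per iteration, after all step() calls
    scaler.update()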

and I cannot reproduce the reported error.

Your code works for me too when I run it in the console, yet I still couldn't get rid of the AssertionError in my own code, even though GradScaler() itself seems to work.

It turned out that the error was caused by another bug unrelated to GradScaler. I rechecked my code, found what was wrong, and now no errors are reported.

Thanks a lot! :smiley:

Good to hear you've found the issue! Would you mind explaining what went wrong in your code? I would be interested in what raised the error.

It sounds dumb :smiling_face_with_tear:. In my mean-teacher framework, I load the teacher model's weights from a trained model, and I forgot to move the teacher model out of the torch.no_grad() block. As a result, no gradients flowed back to the teacher, which is probably what caused the AssertionError.
After fixing this, no errors are reported anymore.
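This matches how the scaler works: unscale_() skips parameters whose .grad is None, so if none of an optimizer's parameters received a gradient (for example because the forward pass ran under torch.no_grad()), no inf check is recorded and scaler.step() fails with "No inf checks were recorded for this optimizer." A quick way to find the culprit is to check, after backward() and before scaler.step(), which optimizers have no gradients at all. The helper optimizers_without_grads and the student/teacher models below are hypothetical names for a minimal sketch of that check.

```python
import torch
import torch.nn as nn


def optimizers_without_grads(*optimizers):
    """Return the indices of optimizers whose parameters all have .grad == None.

    Hypothetical helper: such an optimizer would make scaler.step() raise
    "No inf checks were recorded for this optimizer."
    """
    missing = []
    for i, opt in enumerate(optimizers):
        has_grad = any(
            p.grad is not None
            for group in opt.param_groups
            for p in group["params"]
        )
        if not has_grad:
            missing.append(i)
    return missing


student = nn.Linear(2, 1)
teacher = nn.Linear(2, 1)
opt_student = torch.optim.SGD(student.parameters(), lr=0.1)
opt_teacher = torch.optim.SGD(teacher.parameters(), lr=0.1)

x = torch.randn(4, 2)
with torch.no_grad():
    # The bug from this thread: the teacher's forward pass is inside no_grad,
    # so its output is detached and its parameters never receive gradients.
    teacher_out = teacher(x)
loss = (student(x) - teacher_out).pow(2).mean()
loss.backward()

print(optimizers_without_grads(opt_student, opt_teacher))  # [1]
```

Run the check after backward() rather than right after zero_grad(): with set_to_none=True (the default in recent PyTorch versions), zero_grad() resets every .grad to None.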
