Model training with automatic mixed precision is not learning

Using mixed precision, the loss flattens out after the first few iterations. The model trains fine without mixed precision.

Here is an example of how it’s implemented. I am using a CTC loss function:

for i, _data in enumerate(train_loader):
    spectrograms, labels, input_lengths, label_lengths = _data 
    spectrograms, labels = spectrograms.cuda(), labels.cuda()
    with autocast():
        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)

        loss = ctc_loss(output, labels, input_lengths, label_lengths)
        loss = loss / args.grad_acc_steps
        scaler.scale(loss).backward()

    if i % args.grad_acc_steps == 0:
        scaler.unscale_(optimizer)  # unscale to clip gradient
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

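Try moving scaler.scale(loss).backward() out of the autocast context. Autocast should wrap only the forward pass and the loss computation; backward ops already run in the same dtype that autocast chose for the corresponding forward ops: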

    with autocast():
        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)

        loss = ctc_loss(output, labels, input_lengths, label_lengths)
        loss = loss / args.grad_acc_steps
    scaler.scale(loss).backward()

Also, what does “a few iterations” mean here? It’s expected that scaler.step(optimizer) may skip the first few steps due to inf/nan gradients as the scale value calibrates, so the loss would not decrease for those iterations.


Thanks, @mcarilli. In this context, “a few iterations” means about 1000 iterations.

And you’re right: scaler.scale(loss).backward() should be outside the autocast context. However, the actual fix had to do with how GradScaler performs dynamic loss scaling. For my specific use case, I had to set the growth_interval parameter to something smaller, like 10 iterations. The issue was that the scale value was not high enough, which caused a flat loss. Once the model started learning, I reset growth_interval to its default value of 2000 iterations.

My updated code is as follows:


from torch import nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler(growth_interval=10)  # grow the scale every 10 clean steps early on
for i, _data in enumerate(train_loader):
    spectrograms, labels, input_lengths, label_lengths = _data
    spectrograms, labels = spectrograms.cuda(), labels.cuda()
    with autocast():
        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)

        loss = ctc_loss(output, labels, input_lengths, label_lengths)
        loss = loss / args.grad_acc_steps
    scaler.scale(loss).backward()  # backward now runs outside autocast

    if i % args.grad_acc_steps == 0:
        scaler.unscale_(optimizer)  # unscale to clip gradient
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
    # Once the model is learning, restore the default growth interval.
    # total_iter: the author's global iteration counter, defined outside this snippet
    if scaler.get_growth_interval() != 2000 and total_iter > 1000:
        scaler.set_growth_interval(2000)

This solution works for me and hopefully it can help someone who runs into the same issue.


Glad it’s working, and an interesting discovery. However, I don’t see the issue as solved yet. I think we can make it work better for your model immediately, and also help prevent this issue for future users.

In our experience, GradScaler’s default constructor values rarely need to be tuned. Yours is the first case I’m aware of with the native API, and we tried it with about 40 models spanning many applications before merging. The intention was to supply default values (and a dynamic scale-finding heuristic) that are effective for the vast majority of networks, so GradScaler’s args don’t become additional “hyperparameters.”

The default init_scale is intended to be larger than the network initially needs. The large initial value causes inf/nan gradients for the first few iterations, but the scale quickly calibrates down to a value that no longer overflows, because each iteration with inf/nan gradients is skipped and multiplies the scale by backoff_factor. After that, the large default growth_interval (2000) means few iterations are skipped, and the effect on performance is negligible.
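To make the calibration behavior concrete, here is a toy pure-Python model of the update rule described above; it does not call PyTorch. The function name simulate_scale and the max_safe_scale threshold are hypothetical: the sketch assumes gradients overflow exactly when the scale exceeds a fixed value, which real networks only roughly do. The keyword defaults mirror GradScaler’s documented defaults (init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000).

```python
def simulate_scale(max_safe_scale, steps, init_scale=2.0**16,
                   growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    """Toy model of GradScaler-style dynamic loss scaling.

    Hypothetical assumption: gradients overflow to inf/nan exactly when the
    scale exceeds max_safe_scale. Returns (scale after each update, skipped steps).
    """
    scale = init_scale
    growth_tracker = 0  # consecutive steps without inf/nan gradients
    history, skipped = [], 0
    for _ in range(steps):
        if scale > max_safe_scale:
            # inf/nan gradients: the optimizer step is skipped and the scale backs off
            skipped += 1
            scale *= backoff_factor
            growth_tracker = 0
        else:
            # clean step: the optimizer step runs
            growth_tracker += 1
            if growth_tracker == growth_interval:
                scale *= growth_factor
                growth_tracker = 0
        history.append(scale)
    return history, skipped


history, skipped = simulate_scale(max_safe_scale=1024.0, steps=20)
print(skipped)      # 6
print(history[-1])  # 1024.0
```

Starting from 2**16 with a hypothetical safe maximum of 1024, the first six steps are skipped while the scale halves its way down, and every later step succeeds.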

In your case, it appears you’re in the opposite situation: the default init_scale is smaller than you initially need. growth_interval=10 is one way to increase the scale more quickly than it otherwise would, but once the value calibrates/stabilizes, roughly 1 out of 10 iterations will be skipped (a 10% training slowdown). You work around this by resetting growth_interval later, which is smart, but also inconvenient and not obvious. If all you need is a higher initial value, I’d construct GradScaler(init_scale=<bigger value>) instead of playing with growth_interval in multiple places.
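The “roughly 1 out of 10” figure follows from the same update rule. A back-of-the-envelope sketch, assuming an idealized, fixed overflow threshold (real thresholds drift): once calibrated, the scale survives growth_interval clean steps, grows, overflows on the next step and backs off again, so about one step in growth_interval + 1 is skipped. The function name is hypothetical.

```python
def steady_state_skip_fraction(growth_interval):
    """Idealized fraction of skipped steps after the scale has calibrated.

    Assumes a fixed overflow threshold: each cycle is `growth_interval`
    clean steps followed by one overflowing (skipped) step.
    """
    return 1.0 / (growth_interval + 1)


for interval in (10, 2000):
    print(interval, f"{steady_state_skip_fraction(interval):.2%}")
# 10 9.09%
# 2000 0.05%
```

Raising init_scale instead, as suggested above, leaves growth_interval at its default, so the steady-state cost stays near the second figure.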

Per the paragraphs above, the best practice is to supply an init_scale that’s larger than your network needs, so the scale quickly calibrates down, then stabilizes. To do that, we need to figure out the value it calibrates to. Can you rerun your existing code (with growth_interval=10), print scaler.get_scale() just after scaler.update() for the first few dozen steps to get a sense of the scale value it finds, and post the results here?
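For reading such a printout, a small helper can count skipped steps and find the most frequent scale. It relies on one property of the update rule: with backoff_factor < 1, the scale only decreases after inf/nan gradients, so every decrease marks a skipped optimizer step. (Inside the training loop, the same check can be made by comparing scaler.get_scale() before and after scaler.update().) The helper’s name and the sample list are made up for illustration; the list is not data from this thread.

```python
from collections import Counter


def summarize_scale_log(scales, init_scale=2.0**16):
    """Summarize scaler.get_scale() values logged right after each update().

    Returns (number of skipped steps, most frequent scale value).
    A decrease relative to the previous value (or to init_scale for the
    first entry) means that iteration's optimizer step was skipped.
    """
    previous = [init_scale] + list(scales[:-1])
    skipped = sum(1 for prev, cur in zip(previous, scales) if cur < prev)
    most_common_scale = Counter(scales).most_common(1)[0][0]
    return skipped, most_common_scale


# Hypothetical log, for illustration only
sample_log = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
print(summarize_scale_log(sample_log))  # (3, 16384.0)
```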

For you, the best init_scale would then be the next-greatest power of two* above the value it finds, and you can then ignore growth_interval. For me, the value it finds justifies a PR to increase the default init_scale, reducing the likelihood of this issue in the future. A larger initial scale value doesn’t do much harm for any network (worst case, it causes a few more iterations at the beginning to be skipped).

(*Powers of two are best for init_scale, growth_factor, and backoff_factor because multiplication/division by powers of two is exact on non-denormal IEEE floats: it changes only the exponent, not the mantissa.)
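A quick check of the footnote and of the “next-greatest power of two” rule in plain Python. math.frexp splits a float into mantissa and exponent, so it shows that scaling by 2**k leaves the mantissa unchanged. next_power_of_two_above is a hypothetical helper; it reads “above” as strictly greater, which matters because, with the default factors, GradScaler’s scale values are themselves powers of two.

```python
import math


def next_power_of_two_above(x):
    """Smallest power of two strictly greater than x (x > 0)."""
    # frexp gives x = m * 2**e with 0.5 <= m < 1, so 2**e is strictly above x
    _, exponent = math.frexp(x)
    return math.ldexp(1.0, exponent)


# Scaling by a power of two changes only the exponent, so the round trip is exact
# (as long as nothing overflows or becomes denormal).
value = 0.1
for k in (-8, 10, 16):
    scaled = value * 2.0**k
    assert math.frexp(scaled)[0] == math.frexp(value)[0]
    assert scaled / 2.0**k == value

print(next_power_of_two_above(1000.0))  # 1024.0
print(next_power_of_two_above(1024.0))  # 2048.0
```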


@mcarilli, thanks for the write-up. You asked:

“Can you rerun your existing code (with growth_interval=10) and print scaler.get_scale() just after scaler.update() for the first few dozen steps to get a sense for the scale value it finds, and post the results here?”

The results for scaler.get_scale() are as follows:

iter 1-24: drops from 65536.0 to 0.00390625
iter 24-256: rises from 0.00390625 to 32768.0
iter 256-1000: bounces between 1024.0 and 16384.0, settling around 1024.0 near iteration 1000
iter 1000-50000: bounces between 256.0 and 2048.0; from eyeballing it, 1024.0 seems to occur most frequently
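Since every value in this log is a power of two, the log-scale arithmetic can be checked directly. Reading the iteration boundaries as approximate, the drop corresponds to 24 halvings in about 24 iterations (so nearly every early step overflowed and was skipped), and the climb corresponds to 23 doublings, which with growth_interval=10 need at least 230 clean iterations, close to the roughly 232 iterations between 24 and 256:

```latex
\frac{65536}{0.00390625} = \frac{2^{16}}{2^{-8}} = 2^{24}
\qquad
\frac{32768}{0.00390625} = \frac{2^{15}}{2^{-8}} = 2^{23},
\quad 23 \times 10 = 230 \approx 256 - 24 = 232
```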

It seems my network needs a wide range of scale values at the beginning before it stabilizes, and the scale then bounces around once the model starts learning.

That’s wild. I’ve never seen behavior like that before. It doesn’t just decrease to a stable value; it goes way down, then bounces back up. The growth_interval manipulation may be the best approach, since it aggressively tries many new values in the beginning.

Does the full run converge to roughly the same accuracy as FP32?


Yup, the model converges to similar accuracy as FP32! Thanks for the help.

Which learning rate scheduler are you using?
Could you remove the scheduler and check whether you still see this shaky behavior?

Aren’t you missing optimizer.zero_grad() in the loop?

Yes, it’s missing in the example code, but my actual code has it. Thanks for pointing it out!
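On the missing optimizer.zero_grad(): in PyTorch, .grad buffers accumulate across backward() calls until they are zeroed, so with gradient accumulation zero_grad() belongs once per window, right after scaler.step() and scaler.update(). The condition i % args.grad_acc_steps == 0 also deserves a look: because enumerate starts at 0, the first optimizer step happens after a single micro-batch. The toy counter below (hypothetical name, no PyTorch) shows how many micro-batches each step would see, assuming zero_grad() is called right after every step.

```python
def accumulation_windows(num_batches, grad_acc_steps, offset):
    """Number of micro-batches accumulated into each optimizer step.

    offset=0 mimics `if i % grad_acc_steps == 0` (the thread's condition);
    offset=1 mimics `if (i + 1) % grad_acc_steps == 0`.
    Assumes zero_grad() runs right after every optimizer step.
    """
    windows, pending = [], 0
    for i in range(num_batches):
        pending += 1  # one backward() per micro-batch adds into .grad
        if (i + offset) % grad_acc_steps == 0:
            windows.append(pending)  # the optimizer step consumes this window
            pending = 0              # zero_grad() starts a fresh window
    return windows


print(accumulation_windows(8, 4, offset=0))  # [1, 4]
print(accumulation_windows(8, 4, offset=1))  # [4, 4]
```

With the (i + 1) form, every step sees exactly grad_acc_steps micro-batches, which matches the loss / args.grad_acc_steps normalization in the code above.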

I’m using the one-cycle scheduler! I’ve moved on from the issue for now, but the scheduler could also be a factor, and I want others to be aware of that.