Switching between mixed-precision and full-precision training after training has started

Hi, using apex.amp or torch amp, is it possible to switch between mixed-precision training and full-precision training after training has started? For example, I might want to train the first 100 iterations in full precision and then switch to mixed precision after those 100 iterations. Is this possible?

Yes, you could use the enabled argument of torch.cuda.amp.autocast, as well as the enabled argument passed when creating the GradScaler, and set them to True or False during your training run.
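A minimal sketch of that idea, assuming a toy nn.Linear model, random data, and a hypothetical AMP_START threshold of 100 iterations. autocast's enabled flag is chosen fresh on every iteration, while the GradScaler is created once and only called on mixed-precision iterations. Mixed precision is only switched on when CUDA is available, so on a CPU-only machine the sketch simply runs everything in full precision.

```python
import torch
from torch import nn

# Assumption: a toy model and random data stand in for a real training setup.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# The scaler is created once; it is only used on mixed-precision iterations.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

AMP_START = 100  # hypothetical: full precision for iterations 0..99


def train_step(iteration, x, y):
    # Decide the precision for this iteration (mixed precision needs CUDA here).
    use_amp = iteration >= AMP_START and device == "cuda"
    optimizer.zero_grad(set_to_none=True)
    # autocast's enabled flag can be set to a different value on every iteration.
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = loss_fn(model(x), y)
    if use_amp:
        # Scaled backward pass, unscale + step, then adjust the scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        # Plain full-precision backward pass and optimizer step.
        loss.backward()
        optimizer.step()
    return loss.item(), use_amp


x = torch.randn(8, 16, device=device)
y = torch.randn(8, 4, device=device)
for iteration in range(200):
    loss, used_amp = train_step(iteration, x, y)
```

The parameters stay in float32 throughout; autocast only changes the dtype of the operations run inside its region, which is why the flag can change between iterations without touching the model or optimizer.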


A quick question: when training switches frequently between mixed-precision and full-precision mode (say, every 150 iterations), why is the accuracy gap between pure full-precision training and this switching scheme so large? The accuracy gap between pure full-precision and pure mixed-precision training is not that large.

Also, I'm not sure how to set GradScaler's enabled=False dynamically. Isn't it set once when the scaler is created before training, with no way to change it afterwards?

I don't know what might be causing this effect, as I haven't seen or run any experiments that change the precision frequently.

Yes, usually you would either enable or disable amp for the whole script, which is what the enabled arguments are for. If you want to switch the precision dynamically, you could use if conditions to pick the appropriate code path.
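One way to write such an if condition, sketched for the alternating schedule asked about in the follow-up question (150 full-precision iterations, then 150 mixed-precision iterations, and so on). amp_enabled is a hypothetical helper, not part of torch.cuda.amp; its result would decide both the autocast enabled flag and which backward path runs in that iteration.

```python
def amp_enabled(iteration: int, period: int = 150, start_with_amp: bool = False) -> bool:
    """Return True if `iteration` should run under mixed precision.

    Hypothetical helper: alternates between full precision and mixed precision
    every `period` iterations, starting in full precision by default.
    """
    phase = iteration // period  # 0 for iterations 0..149, 1 for 150..299, ...
    in_odd_phase = phase % 2 == 1
    # Flip the pattern if training should start in mixed precision instead.
    return in_odd_phase != start_with_amp
```

In the training loop, use_amp = amp_enabled(iteration) is then computed at the top of each iteration and checked wherever amp-specific code runs.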

So, for example, I want full-precision training for the first 150 iterations, then mixed precision for the next 150 iterations, alternating like this for the rest of training. The scaler is created before training begins, i.e. scaler = torch.cuda.amp.GradScaler(enabled=use_amp). Do you mean that during training, say after 150 iterations, I need to re-initialize this scaler with enabled set to True or False?

No, just skip the scaler calls in iterations where amp is disabled:

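if use_amp:
    # Mixed precision: scale the loss, step through the scaler, update the scale
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:
    # Full precision: the scaler is not used at all
    loss.backward()
    optimizer.step()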

Thanks for the response! torch amp now runs correctly. Another quick question: how can I switch between apex amp O0 and O1/O2/O3 during training? Is checkpointing the model and reloading it the only way to do this, since model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) is called before training begins?

apex.amp is deprecated; you should use the native mixed-precision utilities in torch.cuda.amp instead.
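With native amp there is no counterpart to amp.initialize(): the model and optimizer are created once and keep their float32 parameters, and autocast only changes the dtype of the operations inside its region, so precision can be flipped per iteration without a checkpoint/reload. A minimal sketch, assuming a toy model and random data. It uses torch.autocast, where torch.autocast(device_type="cuda", ...) is equivalent to torch.cuda.amp.autocast, and falls back to bfloat16 on CPU; that fallback is an assumption made only so the sketch runs without a GPU.

```python
import torch
from torch import nn

# No amp.initialize(): model and optimizer are built once and stay in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumption: float16 on CUDA; bfloat16 on CPU so the sketch runs without a GPU.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))


def step(x, y, use_amp):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
        out = model(x)
    # Compute the loss in float32, outside the autocast region.
    loss = nn.functional.mse_loss(out.float(), y)
    if use_amp and scaler.is_enabled():
        # Loss scaling is only needed for float16 gradients on CUDA.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return out.dtype  # dtype autocast produced for the forward pass


x = torch.randn(8, 16, device=device)
y = torch.randn(8, 4, device=device)
for i in range(6):
    # Flip precision every iteration; nothing is re-initialized in between.
    step(x, y, use_amp=(i % 2 == 1))

# If you do checkpoint, save the scaler state next to the model and optimizer.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}
```

The forward output switches dtype with the flag, while the parameters and their gradients remain float32 in both modes.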