Do I need to save the state_dict of GradScaler?

Do I need to save the state_dict of torch.cuda.amp.GradScaler and reload it to resume training? The docs say it dynamically estimates the scale factor each iteration, so I never saved it. So, will model.load_state_dict and optimizer.load_state_dict suffice?


If you want to restore the last scale factor (as well as the backoff and growth factor, if changed), then you should restore its state_dict.
Your training should also work without restoring the gradient scaler, but will most likely not reproduce the same results as a run without interruptions, as the new gradient scaler could skip iterations at different steps.

Hi @ptrblck Recently I encountered an interesting thing. When I trained my model, I didn’t save states of GradScaler, but saved optimizer states. Later I found that if I resumed training by loading the checkpoint, the performance was higher than continuing the training. The only difference I spot is that the GradScaler is newly created in the resumed session. So is it possible that if we recreate/reset GradScaler at each epoch, the performance would be actually higher than using the old GradScaler? Thanks.

I wouldn’t expect to see a difference in performance while using a new GradScaler.
In case you are recreating a new GradScaler, the first iteration(s) might be skipped due to a high initial scaling factor, so your runs might be a bit different. You could thus check if resuming the training in FP32 (i.e. without the GradScaler) and if resuming with the GradScaler + skipping the same number of update steps would also yield a different performance.

Wow you replied so quickly. Thank you @ptrblck ! What if the first few iterations are not skipped, can I say that using a new scaler is equivalent to increasing the learning rate in the first few iterations? Seems there’s no warning message saying the gradients are scaled in the first few iterations…

No, the GradScaler would scale the loss, thus also the gradients, but the gradients would be unscaled again before the optimizer.step() operation is performed, so the effective learning rate is unchanged. The loss scaling also has a growth_factor and growth_interval: after growth_interval consecutive steps without inf/NaN gradients, the scale is multiplied by growth_factor. In case the scaling factor is low right before saving the checkpoint, a resumed run could potentially benefit from a large factor again.