Resuming model training with torch.cuda.amp

Hi,

I have some questions about resuming models when using torch.cuda.amp.

  1. Suppose I trained a model without torch.cuda.amp and saved the model weights and optimizer state with torch.save(). If I then resume training from that checkpoint with torch.cuda.amp enabled, I think there will be an error when training the model. Am I right?

  2. If I want training to work properly after resuming, do I have to use the same torch.cuda.amp setting as before the checkpoint?

For example, are only the following combinations okay?

Train using amp → resume → Train using amp
or
Don’t train using amp → resume → Don’t train using amp

Thanks.

  1. No, there shouldn’t be an error. Your use case is comparable to using torch.cuda.amp to fine-tune a pre-trained model.

  2. The model parameters are all stored in float32 (autocast only changes the precision of operations during the forward pass), so you can continue training with either amp or standard float32 precision.
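
To illustrate both points, here is a minimal sketch of the "don't train using amp → resume → train using amp" case. The tiny nn.Linear model, the random data, the checkpoint keys ("model", "optimizer") and the in-memory buffer standing in for a checkpoint file are all assumptions made for the example, not something from your setup. On a machine without a GPU, enabled=False turns autocast and GradScaler into no-ops, so the script should still run.

```python
import io

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # amp is only switched on when a GPU is present


def make_model_and_optimizer():
    # Hypothetical toy model; replace with your own architecture.
    model = nn.Linear(8, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


x = torch.randn(16, 8, device=device)
y = torch.randn(16, 2, device=device)

# Phase 1: one plain float32 training step, then save a checkpoint.
model, optimizer = make_model_and_optimizer()
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save(
    {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
    buffer,
)
buffer.seek(0)

# Phase 2: rebuild the model, load the checkpoint, continue with amp.
model, optimizer = make_model_and_optimizer()
checkpoint = torch.load(buffer, map_location=device)
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # fresh scaler, no saved state
with torch.cuda.amp.autocast(enabled=use_amp):
    loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
scaler.scale(loss).backward()  # scales the loss only when amp is enabled
scaler.step(optimizer)
scaler.update()

# The parameters themselves stay in float32 even after the amp step.
param_dtypes = {p.dtype for p in model.parameters()}
print(param_dtypes)
```

For the "train using amp → resume → train using amp" case, you may also want to add scaler.state_dict() to the checkpoint and restore it with scaler.load_state_dict(), so the loss scale picks up where it left off instead of starting from its default value.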