Loading state from file

According to the tutorials

one has to define a model and an optimizer first and then load the states from file. But I am wondering: how am I supposed to initialize, for instance, the optimizer if I only have access to its parameters later?

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

thank you

I think you are confusing parameters with state variables.

On the second line, where you initialize the optimizer, you pass it constructor arguments that define that instance (for example, the learning rate or momentum).

The state dict you load later, however, contains (in the case of the optimizer) its internal state, such as momentum buffers or Adam's running averages of past gradients. That state is required to continue training from a checkpoint correctly.

Note that you can save any extra variable in the checkpoint when calling torch.save, such as the current learning rate, so that you can initialize the optimizer correctly in the first place.
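A minimal sketch of that idea, with a toy model, a made-up file name, and arbitrary hyperparameter values of my own choosing: the learning rate is stored as a plain entry in the checkpoint dict, read back first, and used to construct the new optimizer before the state dicts are loaded.

```python
import torch
import torch.nn as nn

# Toy model and optimizer, just for illustration
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

# Save extra bookkeeping alongside the state dicts (the keys are arbitrary)
torch.save({
    'epoch': 3,
    'lr': optimizer.param_groups[0]['lr'],
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pt')

# Later: read the saved lr first, then build the optimizer with it
checkpoint = torch.load('checkpoint.pt')
model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=checkpoint['lr'])
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
```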

Also note that the optimizer needs the model's parameters to make learning happen. These are not to be confused with the initialization parameters of the optimizer itself; the optimizer's constructor has a dedicated params argument for them.


Thank you for your reply, maybe I should clarify. Everything you say is true, but I was wondering whether I am supposed to pass in model.parameters() even though they are only randomly initialized weights at that point, along with some default lr, only for both to be overwritten later when the actual state is loaded… Seems cumbersome.

You should pass the parameters of the model that you are creating. The parameters of this model and the state of the optimizer will be overwritten upon loading the checkpoint.

Note that if you have not saved the state of the optimizer, then you cannot recover it, since for some optimizers, such as Adam, the state depends on the gradients of your model during training.

Ya I know =) Thank you. It just bugs me having to put in some random values that are overwritten shortly after.

Sure! Note that the purpose of passing model.parameters() to the optimizer is to tell the optimizer which parameters are going to be optimized.

It’s not really passing the weights themselves that’s important, it’s having a reference to the weights later on during the training! So that the optimizer can compute the gradients and so on.
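To make the "reference, not values" point concrete, here is a small sketch with a toy nn.Linear model (names of my own choosing). The optimizer stores the very same tensor objects the model owns, and load_state_dict copies new values into those tensors in place, so the optimizer automatically sees the loaded weights.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The optimizer holds references to the model's tensors, not copies
same_tensor = optimizer.param_groups[0]['params'][0] is model.weight

# load_state_dict copies values in place, so the references stay valid
model.load_state_dict({'weight': torch.ones(1, 3), 'bias': torch.zeros(1)})
still_same = optimizer.param_groups[0]['params'][0] is model.weight
```

This is why the initial random values are harmless: the optimizer never cares about them, only about which tensors it is supposed to update.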

I know that, it just would make more sense here to either construct the model and optimizer directly with loaded values or be able to default-construct them in the first place…

As @alex.veuthey has indicated, passing model.parameters() provides a reference to the parameters that need to be learned; the initial values do not matter. Creating an optimizer requires this reference, otherwise having an optimizer is pointless.

So, the steps are as follows:

  1. Create an initial model
  2. Create an optimizer that holds a reference to the parameters of the model it will train
  3. Load the state of the model and the state of the optimizer to resume training
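The three steps above can be sketched end-to-end (a toy model, an arbitrary file name, and hyperparameters of my own choosing; the first block just stands in for an earlier training run that produced a checkpoint):

```python
import torch
import torch.nn as nn

# A previous run: train briefly, then checkpoint model and optimizer
trained = nn.Linear(2, 1)
trained_opt = torch.optim.Adam(trained.parameters(), lr=1e-3)
loss = trained(torch.randn(8, 2)).pow(2).mean()
loss.backward()
trained_opt.step()  # gives Adam some internal state worth saving
torch.save({'model_state_dict': trained.state_dict(),
            'optimizer_state_dict': trained_opt.state_dict()}, 'resume.pt')

# Step 1: create an initial model (weights are random placeholders)
model = nn.Linear(2, 1)
# Step 2: create an optimizer holding a reference to those parameters
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Step 3: load both states; the placeholder values are overwritten
checkpoint = torch.load('resume.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Training can now continue exactly where the previous run left off
```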