Saving model AND optimiser AND scheduler

Hi,
I want to be able to have a model/optimiser/scheduler object which I can hot-plug and play.
So, for example, have a list of such objects, load each to the GPU in turn, do some training, then switch objects.
Maybe then load some earlier ones and pick up training where we left off last time.
I’d like to be able to easily (deep) copy these objects, and save/load to disk.
Note - some models or optimisers or schedulers may be different in these different objects.

One idea: use torch.save(model). This will pickle the model class so the object and its state_dict can be reproduced on loading, but will it restore the scheduler/optimiser too?
Can I use torch.save with these as well to reproduce them later?
How do I control which model they are attached to?

And how can I duplicate the whole model/optimizer/scheduler whilst in memory?

Any suggestions on the best way to do this please?
Many thanks in advance!

You can create a dictionary with everything you need and save it using torch.save(). Example:

checkpoint = {
    'epoch': epoch,                       # current epoch counter
    'model': model.state_dict(),          # model weights
    'optimizer': optimizer.state_dict(),  # optimizer state (e.g. momentum buffers)
    'lr_sched': lr_sched}                 # scheduler stored as a whole object
torch.save(checkpoint, 'checkpoint.pth')

Then you can load the checkpoint by doing checkpoint = torch.load('checkpoint.pth').
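For example, the loading side could look like this (just a sketch, assuming you first construct the same model and optimizer classes as before):

checkpoint = torch.load('checkpoint.pth')
epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['model'])          # model must already be instantiated
optimizer.load_state_dict(checkpoint['optimizer'])  # same for the optimizer
lr_sched = checkpoint['lr_sched']                   # was saved as a whole object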
More info here: Loading a saved model for continue training

Hi, many thanks for your quick reply! :slight_smile:
So I’ve seen that. But the problem with that is it only saves the model’s state dict.
If different objects have, say, different types of models, how do I know which model class to instantiate before giving it the saved state dict?
The same goes for optimizers and schedulers.

Then you can save the model itself, i.e. without calling state_dict(). This will dump the whole content of the variable model into a pickle file, which leads to a larger file than in the previous case. The same applies to the optimizer. The scheduler (at least in the version I’m using) does not have a state_dict(), so the whole variable is saved anyway.
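For example (a sketch only; the file names are placeholders, and the class definitions still need to be importable when you load):

torch.save(model, 'model.pth')          # pickles the whole model object
torch.save(optimizer, 'optimizer.pth')  # pickles the whole optimizer object

model = torch.load('model.pth')         # no need to instantiate the classes yourself
optimizer = torch.load('optimizer.pth')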

So I save the model. And then the optimizer. They will be in two separate files? Will they be connected to each other correctly on reloading? And if I want to, say, create 10 duplicates of this “group”, is there a way to do that?
Many thanks :smiley:

They will be in two separate files?

You can save them in separate files or wrap them in a dictionary like the one I showed (where you remove the calls to state_dict()).

Will they be connected to each other correctly on reloading?

There is no loss of information when you save the whole variable instead of saving the state_dict() only, so I assume that the model and the optimizer will remain “connected” after you reload them (although I’ve never tried it).
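One way to check that (just a sketch, assuming the model and optimizer were saved together in the same checkpoint dictionary, since pickle keeps shared references within a single file):

checkpoint = torch.load('checkpoint.pth')
model, optimizer = checkpoint['model'], checkpoint['optimizer']

# the optimizer's param groups should point at the very same tensors as the reloaded model
model_params = {id(p) for p in model.parameters()}
optim_params = {id(p) for group in optimizer.param_groups for p in group['params']}
print(optim_params.issubset(model_params))  # expected: True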

And if I want to, say, create 10 duplicates of this “group”, is there a way to do that?

I’m not sure I understand what you mean by a “duplicate”… Do you mean, for instance, saving a checkpoint of the training after every epoch? You can save each checkpoint file under a different name and that’s it.
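For example (a trivial sketch, just putting the epoch number in the file name):

torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')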

If I store them in the same file, it only stores the state_dicts for each. It will not pickle the object. That’s a problem if different “objects” have different models.
Separate files sound troublesome. Which order do I load them in, and will they still work for training?
Imagine one of a group of 10 model/optimizer/scheduler combinations does particularly well after round one, where each has had an hour on the GPU. This may be because I gave it a favourable set of hyperparameters.
I want to take this group, duplicate it x10, and then run each of those with slightly different hyperparameters.
In essence, I want to be able to duplicate the whole model/optimiser/scheduler state.

If I store them in the same file, it only stores the state_dicts for each. It will not pickle the object.

This is not true. As long as you do not call state_dict(), it will save the whole variable. Please try the following.

For saving:

checkpoint = {
    'epoch': epoch,
    'model': model,          # the whole model object, not just its state_dict
    'optimizer': optimizer,  # the whole optimizer object
    'lr_sched': lr_sched}
torch.save(checkpoint, 'checkpoint.pth')

For loading:

checkpoint = torch.load('checkpoint.pth')
epoch = checkpoint['epoch']
model = checkpoint['model']
optimizer = checkpoint['optimizer']
lr_sched = checkpoint['lr_sched']

What’s the problem with this approach?

Oh I see what you mean. I’m sorry - me stupid! :crazy_face:
Your help is much appreciated! :grin:

Oh, sorry, one last thing: I wondered about copying the whole thing, say x10?
Do I just reload the same thing from disk multiple times?

I’ve been working on a checkpoint helper for such a use case. It is still a WIP, but check it out: https://pypi.org/project/pytorchcheckpoint/

You can reload the same pickle file as many times as you need.
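To sketch what that could look like (the copy.deepcopy option and the per-copy learning rates are my own assumption, not something tested in this thread; recent PyTorch versions may also need torch.load(..., weights_only=False) to unpickle whole objects):

import copy
import torch

# option 1: reload the same pickle several times; every load is an independent copy
groups = [torch.load('checkpoint.pth') for _ in range(10)]

# option 2: load once and deep-copy the whole group in memory
base = torch.load('checkpoint.pth')
groups = [copy.deepcopy(base) for _ in range(10)]

# e.g. give each copy slightly different hyperparameters before resuming training
for i, g in enumerate(groups):
    for param_group in g['optimizer'].param_groups:
        param_group['lr'] = 1e-3 * (i + 1)  # hypothetical per-copy learning rate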
