Loading weights from a saved model


In my case I have a saved model named best_model.pkl, which contains multiple models (pose, depth and others). I want to extract the individual models from the existing best_model.pkl and save each of them as a separate .pth file.

I am doing something along these lines:

import torch

model_to_save = torch.load('best_model.pkl')
print(model_to_save.keys())

which gives the following output:

dict_keys(['epoch', 'model_state', 'optimizer_state', 'scheduler_state', 'best_iou'])

From these I select model_state, which contains the parameters of all the individual models. Now I want to save each individual model with the following code:

models = ["depth", "pose_encoder", "pose", "encoder"]
for model_name in models:
    save_path = os.path.join("{}.pth".format(model_name))
    to_save = model_to_save['model_state']['models'][model_name].state_dict()
    torch.save(to_save, save_path)

But I receive the following error:

  File "save.py", line 10, in save_monodepth_models
    to_save = model_to_save['model_state']['models'][model_name].state_dict()
KeyError: 'models'

I see that type(model_to_save['model_state']) is <class 'collections.OrderedDict'>, and its keys are:

odict_keys(['models.encoder.encoder.conv1.weight', 'models.encoder.encoder.bn1.weight',
'models.encoder.encoder.bn1.bias', 'models.encoder.encoder.bn1.running_mean',
'models.encoder.encoder.bn1.running_var', 'models.encoder.encoder.bn1.num_batches_tracked',
'models.encoder.encoder.layer1.0.conv1.weight', 'models.encoder.encoder.layer1.0.bn1.weight',
'models.encoder.encoder.layer1.0.bn1.bias', 'models.encoder.encoder.layer1.0.bn1.running_mean',
'models.encoder.encoder.layer1.0.bn1.running_var', 'models.encoder.encoder.layer1.0.bn1.num_batches_tracked',
'models.encoder.encoder.layer1.0.conv2.weight', 'models.encoder.encoder.layer1.0.bn2.weight',
'models.encoder.encoder.layer1.0.bn2.bias', 'models.encoder.encoder.layer1.0.bn2.running_mean',
'models.encoder.encoder.layer1.0.bn2.running_var', 'models.encoder.encoder.layer1.0.bn2.num_batches_tracked',
'models.encoder.encoder.layer1.0.conv3.weight', 'models.encoder.encoder.layer1.0.bn3.weight',
'models.encoder.encoder.layer1.0.bn3.bias', 'models.encoder.encoder.layer1.0.bn3.running_mean', 'models.encoder.enco

I'm not sure how I should handle this case. Any suggestions would be helpful!


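It seems that model_to_save['model_state'] contains a state_dict, which is the recommended way to serialize models. Its keys are flat, dot-separated parameter names such as 'models.encoder.encoder.conv1.weight' rather than nested dicts of modules, which is why model_to_save['model_state']['models'] raises a KeyError.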
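Since the keys are flat strings, one way to get per-model files is to group them by the name that follows the "models." prefix. This is a minimal sketch, assuming every sub-model key has the form models.<name>.<param> (as in the keys printed in the question, presumably produced by a wrapper module that holds its sub-models in an attribute called models, e.g. an nn.ModuleDict). split_state_dict and save_sub_models are hypothetical helpers, not part of PyTorch.

```python
import os
from collections import OrderedDict

import torch


def split_state_dict(state_dict, prefix="models."):
    """Group a flat state_dict into one state_dict per sub-model.

    Assumes keys look like 'models.<name>.<param>'. For example,
    'models.encoder.encoder.conv1.weight' ends up under 'encoder' with the
    key 'encoder.conv1.weight' (prefix and sub-model name stripped).
    Keys that do not start with `prefix` are skipped.
    """
    per_model = OrderedDict()
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            continue
        # 'encoder.encoder.conv1.weight' -> ('encoder', '.', 'encoder.conv1.weight')
        name, _, param_key = key[len(prefix):].partition(".")
        per_model.setdefault(name, OrderedDict())[param_key] = value
    return per_model


def save_sub_models(checkpoint_path, out_dir="."):
    """Save each sub-model's state_dict as <out_dir>/<name>.pth and return the paths."""
    # On PyTorch >= 2.6 torch.load defaults to weights_only=True; if that rejects
    # objects in the checkpoint, pass weights_only=False, but only for trusted files.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    paths = []
    for name, sub_state in split_state_dict(checkpoint["model_state"]).items():
        path = os.path.join(out_dir, "{}.pth".format(name))
        torch.save(sub_state, path)
        paths.append(path)
    return paths
```

Each resulting file holds a plain state_dict, so it can be loaded into a freshly created instance of the matching model class (assumed to be available from the original training code) via load_state_dict.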
To restore the model, create a model instance first and load the corresponding state_dict afterwards.
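A minimal illustration of that restore pattern follows. TinyEncoder is a hypothetical stand-in for one of the real sub-models, and an in-memory io.BytesIO buffer replaces the .pth file so the example is self-contained.

```python
import io

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    # Hypothetical stand-in for one of the sub-models (e.g. the depth encoder)
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn1(self.conv1(x))


# Serialize only the parameters and buffers (the state_dict), here into memory
trained = TinyEncoder()
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Restore: build a fresh instance first, then load the saved state_dict into it
restored = TinyEncoder()
restored.load_state_dict(torch.load(buffer))
restored.eval()  # switch BatchNorm to inference mode
```

Because only tensors are stored, the file does not depend on where TinyEncoder is defined; you only need some way to construct an instance with the same architecture.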
I would not recommend saving the model directly via torch.save(model). Instead, stick to saving the state_dict via torch.save(model.state_dict()) (in your case you could also wrap it in further dicts if needed), since the former can easily break (e.g. you would need to restore the same file structure, etc.).
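A sketch of what "wrapping it in further dicts" could look like when several models go into one file. The two nn.Linear layers named depth and pose are placeholders for the real sub-models, and the checkpoint keys (epoch, model_state, optimizer_state) merely mirror those seen in best_model.pkl; an io.BytesIO buffer stands in for a file.

```python
import io

import torch
import torch.nn as nn

# Placeholder sub-models standing in for e.g. "depth" and "pose"
models = {"depth": nn.Linear(4, 2), "pose": nn.Linear(4, 6)}
optimizer = torch.optim.SGD(
    [p for m in models.values() for p in m.parameters()], lr=0.1
)

# Save state_dicts only, nested in a plain dict together with extra metadata
checkpoint = {
    "epoch": 10,
    "model_state": {name: m.state_dict() for name, m in models.items()},
    "optimizer_state": optimizer.state_dict(),
}
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# Restore: recreate the modules, then load each sub-model's state_dict by name
loaded = torch.load(buffer)
new_models = {"depth": nn.Linear(4, 2), "pose": nn.Linear(4, 6)}
for name, m in new_models.items():
    m.load_state_dict(loaded["model_state"][name])
```

With this nested layout, loaded["model_state"][name] returns a per-model state_dict directly, so individual models can be extracted by name without any key parsing.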
