Load weights from trained models for initialization

Hi,
I used a resnet50 model (pretrained=True) for training. I saved the best model in my training function like this:

model.load_state_dict(best_model_wts)
return model

Then I called my training function:

trained_model = training_func(.....)
torch.save(trained_model, 'trained.pth')

Now I want to train again, starting from the weights of my trained model. So what I did is:

pretrained_weights = torch.load('trained.pth')
model = resnet50(pretrained=False)
model.load_state_dict(pretrained_weights)

But it throws an error:

model.load_state_dict(pretrained_weights)
  File "/home/user/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 751, in load_state_dict
    state_dict = state_dict.copy()
  File "/home/user/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 539, in __getattr__
    type(self).__name__, name))
AttributeError: 'ResNet' object has no attribute 'copy'

What am I doing wrong? Can you please tell me? I want to use my trained model’s weights to initialize a new training run with the same model as before, resnet50.

This should work!

trained_model = training_func(.....)
torch.save(trained_model.state_dict(), 'trained.pth')

There are two ways of saving and loading models in PyTorch. You can either save/load the whole model object (Python class, architecture, and weights) or only the weights (the state_dict).
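
To make the two options concrete, here is a minimal sketch. It uses a small nn.Linear layer as a stand-in for resnet50, and the temporary file names are hypothetical (neither comes from the thread). It also assumes a PyTorch version where torch.load accepts the weights_only argument; recent versions may refuse to unpickle a whole model without weights_only=False.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained resnet50 (assumption: any nn.Module behaves the same here).
trained_model = nn.Linear(4, 2)

# Hypothetical file names inside a temporary directory.
tmp_dir = tempfile.mkdtemp()
weights_path = os.path.join(tmp_dir, "weights_only.pth")
model_path = os.path.join(tmp_dir, "whole_model.pth")

# Option 1: save only the weights (the state_dict).
torch.save(trained_model.state_dict(), weights_path)

# Option 2: save the whole model object (the module itself is pickled).
torch.save(trained_model, model_path)

# Loading option 1: build the architecture first, then load the weights into it.
model_a = nn.Linear(4, 2)
model_a.load_state_dict(torch.load(weights_path))

# Loading option 2: torch.load returns the module directly, so there is
# nothing to pass to load_state_dict.
model_b = torch.load(model_path, weights_only=False)
```

The two save calls have to be paired with the matching load style; mixing them is what produces the AttributeError in the question.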

It is explained here

In your case, since you saved the whole model, you can load it using:

model = torch.load('trained.pth')

When training:

trained_model = training_func(.....)
torch.save(trained_model.state_dict(), 'trained.pth')

Then:

pretrained_weights = torch.load('trained.pth')
model = resnet50(pretrained=False)
model.load_state_dict(pretrained_weights)

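With torch.save(trained_model, 'trained.pth') you saved the whole model (its structure together with its state_dict), but model.load_state_dict(pretrained_weights) expects only a state_dict. torch.load gives you back a ResNet object rather than a dict, which is why you get the AttributeError.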
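
If you are not sure which way a checkpoint was saved, one option is to normalize whatever torch.load returned before calling load_state_dict. A minimal sketch, where as_state_dict is a hypothetical helper name and nn.Linear stands in for resnet50:

```python
import torch
import torch.nn as nn


def as_state_dict(loaded):
    """Return a state_dict whether `loaded` is a whole model or already a state_dict."""
    if isinstance(loaded, nn.Module):
        return loaded.state_dict()
    return loaded


# Stand-ins for the two things torch.load can hand back (assumption for illustration):
saved_whole = nn.Linear(4, 2)              # result after torch.save(model, path)
saved_weights = saved_whole.state_dict()   # result after torch.save(model.state_dict(), path)

fresh = nn.Linear(4, 2)
fresh.load_state_dict(as_state_dict(saved_whole))    # whole model: converted first
fresh.load_state_dict(as_state_dict(saved_weights))  # state_dict: passed through unchanged
```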

Thanks, but if you look, what I am returning is

trained_model.state_dict()

So there is no difference between your suggestion and my implementation, is there?

torch.save(trained_model, 'trained.pth')

If this is what your function returns, then you are saving the whole model:

model.load_state_dict(best_model_wts)
return model

I see, that does make sense!! Thanks.

Just for others:
Depending on how the model was saved, my code also works, in the following way:

pretrained_model = torch.load('trainedModel.pth')
model = resnet50(pretrained=False)
model.load_state_dict(pretrained_model.state_dict(), strict=False)
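
One caveat with strict=False: it tells load_state_dict to skip keys that do not match instead of raising, so a mismatch can go unnoticed. When both models are the same resnet50, the keys should match and the default strict=True would likely work too. A minimal sketch of how to check what was skipped, using nn.Linear as a stand-in and a deliberately incomplete state_dict (both are assumptions for illustration):

```python
import torch
import torch.nn as nn

source = nn.Linear(4, 2)

# Deliberately drop the bias to simulate a checkpoint that does not fully match.
partial_state = {k: v for k, v in source.state_dict().items() if k != "bias"}

target = nn.Linear(4, 2)
# With strict=False no error is raised; the mismatches are reported in the return value.
result = target.load_state_dict(partial_state, strict=False)

print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
```

If both lists are empty, every parameter was loaded.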

Yes, it works this way! But with two drawbacks:

  • You are now saving larger files, since the whole model object is stored rather than just the weights (a storage issue).
  • The model is now saved as a pickle object, which doesn’t save the model’s classes themselves but only the paths to them. As mentioned in the documentation, this may cause problems at load time.