Hi all,
I am trying to save my model in PyTorch using the code below:
model = utils.get_model(self.model)
torch.save({
    # 'model_state_dict': model,
    # added new
    'model_state_dict': model.state_dict(),
}, os.path.join(self.checkpoint, 'model_{}.pth'.format(task_id)))
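To make the structure of that file concrete, here is a minimal sketch. It assumes a small hypothetical `TinyNet` in place of whatever `utils.get_model` returns, plus a temporary directory and `task_id=0` chosen only for illustration. It shows that the saved `.pth` holds a dict with a single `'model_state_dict'` key, not the state dict itself.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the model returned by utils.get_model."""

    def __init__(self):
        super().__init__()
        self.conv_layer = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.conv_layer(x)


def save_checkpoint(model, checkpoint_dir, task_id):
    """Wrap the state_dict in a dict under 'model_state_dict', like the snippet above."""
    path = os.path.join(checkpoint_dir, 'model_{}.pth'.format(task_id))
    torch.save({'model_state_dict': model.state_dict()}, path)
    return path


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        path = save_checkpoint(TinyNet(), tmp, task_id=0)
        checkpoint = torch.load(path, map_location='cpu')
        print(list(checkpoint.keys()))
        # ['model_state_dict']
        print(list(checkpoint['model_state_dict'].keys()))
        # ['conv_layer.0.weight', 'conv_layer.0.bias',
        #  'conv_layer.2.weight', 'conv_layer.2.bias']
```

So whoever loads this file gets the wrapper dict back from `torch.load` and has to index it by `'model_state_dict'` before handing it to `load_state_dict`.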
I can load the model in my app with no issues, and it is saved to a .pth file.
My next step is to take the saved file, model.pth, and load it into another application with the code below:
model.load_state_dict(torch.load("./checkpoint/model.pth"))
This gives me the following error:
RuntimeError                              Traceback (most recent call last)
in
----> 1 model.load_state_dict(torch.load("/home/jovyan/.cache/torch/checkpoints/resnext50_32x4d-7cdf4587.pth"))
      2 model = model.eval()

/srv/conda/envs/notebook/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    828         if len(error_msgs) > 0:
    829             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 830                                self.__class__.__name__, "\n\t".join(error_msgs)))
    831         return _IncompatibleKeys(missing_keys, unexpected_keys)
    832

RuntimeError: Error(s) in loading state_dict for Target:
	Missing key(s) in state_dict: "conv_layer.0.weight", "conv_layer.0.bias", "conv_layer.1.weight", "conv_layer.1.bias", "conv_layer.4.weight", "conv_layer.4.bias", "conv_layer.5.weight", "conv_layer.5.bias", "conv_layer.7.weight", "conv_layer.7.bias", "conv_layer.8.weight", "conv_layer.8.bias", "conv_layer.11.weight", "conv_layer.11.bias", "conv_layer.12.weight", "conv_layer.12.bias", "conv_layer.14.weight", "conv_layer.14.bias", "conv_layer.15.weight", "conv_layer.15.bias", "conv_layer.18.weight", "conv_layer.18.bias", "conv_layer.19.weight", "conv_layer.19.bias", "conv_layer.21.weight", "conv_layer.21.bias", "conv_layer.22.weight", "conv_layer.22.bias", "conv_layer.24.weight", "conv_layer.24.bias"
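One likely cause of missing keys when loading a file written by the first snippet: `torch.load` returns the wrapper dict, so `load_state_dict` sees only the key `'model_state_dict'` and reports every real parameter as missing. (The traceback itself loads a different file, a ResNeXt-50 checkpoint, into a model called `Target`; that file's keys would not match `conv_layer.*` names either.) Below is a minimal sketch of unwrapping the checkpoint before loading. The `TinyNet` model and the `load_checkpoint` helper are hypothetical, chosen only for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model whose keys follow the 'conv_layer.N.weight' pattern."""

    def __init__(self):
        super().__init__()
        self.conv_layer = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.conv_layer(x)


def load_checkpoint(model, path):
    """Load weights saved either as a bare state_dict or wrapped under 'model_state_dict'."""
    checkpoint = torch.load(path, map_location='cpu')
    # torch.save({'model_state_dict': ...}) writes a wrapper dict, so unwrap it first;
    # a bare state_dict has no such key and is used as-is.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)  # strict=True by default: raises if keys still mismatch
    return model


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'model.pth')
        torch.save({'model_state_dict': TinyNet().state_dict()}, path)

        # Passing the wrapper dict straight to load_state_dict reproduces the failure.
        try:
            TinyNet().load_state_dict(torch.load(path, map_location='cpu'))
        except RuntimeError as err:
            print(str(err).splitlines()[0])
            # Error(s) in loading state_dict for TinyNet:

        model = load_checkpoint(TinyNet(), path).eval()
```

Keeping `strict=True` is deliberate: if the unwrapped keys still do not match the architecture (as with a checkpoint from a different network), the error surfaces immediately instead of leaving layers silently uninitialized.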
So I started exploring other options in my code, such as also saving the model state separately.
My question is: shouldn't the code below already save my entire model, including its state?
model=utils.get_model(self.model)
torch.save({
Apparently not, so I added the following line to my code:
torch.save(model.state_dict(), os.path.join(self.checkpoint, 'model_state_{}.pth'.format(task_id)))
However, now I get this error:
in save_all_models
    'model_state_dict': model.state_dict(),
AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'
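The AttributeError says `model` is already a `collections.OrderedDict`. Assuming `model` still comes from `model = utils.get_model(self.model)`, that suggests `utils.get_model` returns a state dict rather than an `nn.Module`, and a state dict has no `.state_dict()` method. A minimal sketch that saves correctly in either case, using a hypothetical `to_state_dict` helper and a hypothetical `save_model_state` wrapper:

```python
import os
import tempfile

import torch
import torch.nn as nn


def to_state_dict(obj):
    """Hypothetical helper: return a state dict for an nn.Module or an existing state dict.

    nn.Module.state_dict() returns an OrderedDict, and an OrderedDict has no
    .state_dict() method, which is what the AttributeError complains about.
    """
    if isinstance(obj, nn.Module):
        return obj.state_dict()
    if isinstance(obj, dict):  # collections.OrderedDict is a dict subclass
        return obj
    raise TypeError('expected nn.Module or state dict, got {}'.format(type(obj).__name__))


def save_model_state(obj, checkpoint_dir, task_id):
    """Save the state dict to model_state_<task_id>.pth and return the path."""
    path = os.path.join(checkpoint_dir, 'model_state_{}.pth'.format(task_id))
    torch.save(to_state_dict(obj), path)
    return path


if __name__ == '__main__':
    net = nn.Linear(4, 2)
    with tempfile.TemporaryDirectory() as tmp:
        # Works whether we pass the module or a state dict already extracted from it.
        p1 = save_model_state(net, tmp, task_id=0)
        p2 = save_model_state(net.state_dict(), tmp, task_id=1)
        print(list(torch.load(p1).keys()), list(torch.load(p2).keys()))
        # ['weight', 'bias'] ['weight', 'bias']
```

On whether saving should capture the whole model: a state dict holds only the parameter and buffer tensors keyed by name, so the other application still needs the model class to construct the network and then call `load_state_dict` on it. Calling `torch.save(model)` on an `nn.Module` pickles the whole object instead, but loading that still requires the class definition to be importable.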
Thank you for your help in advance.