Using a PyTorch model for inference

Hi Vishal,

Saving the entire model (architecture plus weights) the way you did is not recommended, because the resulting file generally won't load in a different project. For example, if you sent me your model file ./torch_model_v1 and I ran torch.load("./torch_model_v1"), I would get an error: the saved file records the model's classes by their module paths, and my project almost certainly won't have the same directory structure as yours.

Instead, you should save only the model weights (the state dict), define the architecture in code, and then load the weights into the new model's state dict.

## save
```python
torch.save(model.state_dict(), "./torch_model_v1.pt")
```

## load
```python
# The model must be defined with the same code you used to create the trained model.
model = Model()
state_dict = torch.load("./torch_model_v1.pt")
model.load_state_dict(state_dict)
```
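Here is a minimal, self-contained sketch of the full round trip. The `Model` class below is just a stand-in for whatever architecture you actually trained, and the temp-file path is only for illustration; note the `model.eval()` and `torch.no_grad()` calls, which you'll want before running inference.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Model(nn.Module):
    """Stand-in architecture; replace with your real model definition."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


trained = Model()  # pretend this has been trained

# Save only the weights (state dict), not the whole pickled model object.
path = os.path.join(tempfile.mkdtemp(), "torch_model_v1.pt")
torch.save(trained.state_dict(), path)

# Load: rebuild the architecture in code, then restore the weights into it.
model = Model()
model.load_state_dict(torch.load(path))
model.eval()  # disable dropout / use batch-norm running stats for inference

with torch.no_grad():  # no autograd bookkeeping needed at inference time
    out = model(torch.randn(1, 4))
print(out.shape)  # torch.Size([1, 2])
```

Because the architecture is reconstructed from code, this works in any project that has the `Model` definition importable, regardless of how the original training project was laid out.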
