Comparison between saving the whole model, saving only the state_dict, and TorchScript

I’m trying to figure out the best way to save a model trained with PyTorch and load it for inference, and I was wondering about the different possible approaches.

Let’s say I successfully train a model. As far as I understand, I can use:

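  • Complete Model Saving:
import torch
# save the entire model object (pickled together with its parameters)
torch.save(model, saved_model_path)
# load the model directly; PyTorch 2.6+ defaults to weights_only=True, so unpickling
# a full model needs weights_only=False (only do this for files you trust)
loaded_model = torch.load(saved_model_path, weights_only=False)
loaded_model.eval()
# use it for inference
with torch.no_grad():
    output = loaded_model(input)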
  • State Dict Saving:
# save only the state_dict after training
torch.save(model.state_dict(), saved_model_path)
# create an instance of the model with the same architecture, then load the parameters
model = SomeModelConstructor()
state_dict = torch.load(saved_model_path)
model.load_state_dict(state_dict)
model.eval()
# use it for inference
with torch.no_grad():
    output = model(input)
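To make the two approaches above concrete, here is a minimal, self-contained sketch that saves a toy model both ways and checks that the reloaded copies produce the same output. `TinyNet` is a hypothetical stand-in for `SomeModelConstructor`, and the explicit `weights_only` arguments assume a PyTorch version that supports them (1.13 or newer).

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model standing in for SomeModelConstructor."""

    def __init__(self, in_features: int = 4, out_features: int = 2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


def roundtrip_outputs():
    torch.manual_seed(0)
    model = TinyNet().eval()
    x = torch.randn(3, 4)
    with torch.no_grad():
        expected = model(x)

    with tempfile.TemporaryDirectory() as tmp:
        full_path = os.path.join(tmp, "full_model.pt")
        sd_path = os.path.join(tmp, "state_dict.pt")

        # 1) Complete model: pickles the module object, so TinyNet must be
        #    importable under the same module path when loading
        torch.save(model, full_path)
        full_model = torch.load(full_path, weights_only=False)
        full_model.eval()

        # 2) state_dict: only tensors are saved; the architecture is rebuilt in code
        torch.save(model.state_dict(), sd_path)
        sd_model = TinyNet()
        sd_model.load_state_dict(torch.load(sd_path, weights_only=True))
        sd_model.eval()

        with torch.no_grad():
            out_full = full_model(x)
            out_sd = sd_model(x)

    return expected, out_full, out_sd


expected, out_full, out_sd = roundtrip_outputs()
print(torch.equal(expected, out_full), torch.equal(expected, out_sd))  # True True
```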

As I understand it, saving only the state_dict has the advantage of a smaller file, while the downside is that I have to recreate the model instance. Is that correct?
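One way to check the file-size point for a particular model is to save it both ways and compare. Both files contain the same parameter tensors, so the difference comes from the pickled module structure; the sketch below (using an arbitrary `nn.Sequential` as the example model, chosen so no custom class needs to be pickled) simply reports the two sizes rather than assuming which one is smaller.

```python
import os
import tempfile

import torch
import torch.nn as nn


def saved_sizes(model: nn.Module) -> dict:
    """Return on-disk sizes in bytes for a full-model save vs. a state_dict save."""
    with tempfile.TemporaryDirectory() as tmp:
        full_path = os.path.join(tmp, "full_model.pt")
        sd_path = os.path.join(tmp, "state_dict.pt")
        torch.save(model, full_path)
        torch.save(model.state_dict(), sd_path)
        return {
            "full_model": os.path.getsize(full_path),
            "state_dict": os.path.getsize(sd_path),
        }


# Arbitrary example model: 68,362 float32 parameters (about 273 KB of raw data)
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
sizes = saved_sizes(model)
print(sizes)
```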

I also see that I can export the model to TorchScript and load it:

import torch

# export to TorchScript and save locally
model_scripted = torch.jit.script(model)
torch.jit.save(model_scripted, scripted_model_path)

# load the scripted model
loaded_scripted_model = torch.jit.load(scripted_model_path)

# use for evaluation
loaded_scripted_model.eval()
with torch.no_grad():
    output = loaded_scripted_model(input)
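For reference, a runnable version of the TorchScript round trip, assuming a hypothetical `TinyNet` model and a temporary file path; it compares the reloaded scripted model's output with the original eager model's output.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


torch.manual_seed(0)
model = TinyNet().eval()
x = torch.randn(3, 4)

with tempfile.TemporaryDirectory() as tmp:
    scripted_model_path = os.path.join(tmp, "model_scripted.pt")
    # Compile the module to TorchScript and serialize its code and weights together
    torch.jit.save(torch.jit.script(model), scripted_model_path)
    # Loading does not require the Python class definition of TinyNet
    loaded_scripted_model = torch.jit.load(scripted_model_path)
    loaded_scripted_model.eval()

    with torch.no_grad():
        eager_out = model(x)
        scripted_out = loaded_scripted_model(x)

print(torch.allclose(eager_out, scripted_out))
```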

Assuming I always export the model to TorchScript from Python and load it back in Python, are there any advantages to using the scripted model instead of the saved one (for example, faster inference because the scripted version is better optimized)?
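Whether scripting speeds up inference depends on the model and hardware, so timing both versions is the most direct way to find out. This is only a rough sketch: the example model, batch size, and iteration counts are arbitrary, and `torch.utils.benchmark` would give more careful measurements.

```python
import time

import torch
import torch.nn as nn


def mean_latency(fn, x: torch.Tensor, warmup: int = 10, iters: int = 100) -> float:
    """Average seconds per call of fn(x), measured after some warm-up calls."""
    with torch.no_grad():
        # Warm-up calls keep any one-time setup or optimization out of the timing
        for _ in range(warmup):
            fn(x)
        start = time.perf_counter()
        for _ in range(iters):
            fn(x)
        return (time.perf_counter() - start) / iters


# Arbitrary example model and input
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).eval()
scripted = torch.jit.script(model)
x = torch.randn(64, 512)

eager_s = mean_latency(model, x)
scripted_s = mean_latency(scripted, x)
print(f"eager: {eager_s * 1e3:.3f} ms  scripted: {scripted_s * 1e3:.3f} ms")
```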

Yes, you would need to recreate the model instance when saving only the state_dict. However, note that saving the entire model requires you to exactly recreate the same file structure (including the source code of the model) on the target system. This is quite brittle, and you will see a lot of topics on this discussion board from users asking how to load an entire model that now fails to load after they moved files around or changed function signatures.

Saving and loading the state_dict is thus the recommended approach.
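A common way to make the state_dict approach less tedious is to store the constructor arguments next to the weights, so the loading code can rebuild the architecture on its own. This is a sketch of that general pattern, not something stated in the reply; `MLP`, `save_checkpoint`, and `load_for_inference` are hypothetical names, and the model's source code still has to be importable on the target system.

```python
import os
import tempfile

import torch
import torch.nn as nn


class MLP(nn.Module):
    """Hypothetical model; its source must be importable wherever the checkpoint is loaded."""

    def __init__(self, in_dim: int, hidden: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def save_checkpoint(model: nn.Module, config: dict, path: str) -> None:
    # Save plain data only: the constructor arguments plus the parameter tensors
    torch.save({"config": config, "state_dict": model.state_dict()}, path)


def load_for_inference(path: str) -> MLP:
    # weights_only=True restricts unpickling to tensors and basic containers/types
    ckpt = torch.load(path, map_location="cpu", weights_only=True)
    model = MLP(**ckpt["config"])  # rebuild the architecture in code
    model.load_state_dict(ckpt["state_dict"])
    return model.eval()


# Usage: round-trip a randomly initialized model
config = {"in_dim": 8, "hidden": 16, "out_dim": 3}
model = MLP(**config).eval()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pt")
    save_checkpoint(model, config, path)
    restored = load_for_inference(path)

x = torch.randn(2, 8)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)
```

Because the checkpoint holds only tensors and simple Python types, it can be loaded with `weights_only=True`, and renaming or moving the file that defines `MLP` only requires updating the import, not re-saving the checkpoint.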

Scripted models can be optimized by TorchScript and loaded into the C++ libtorch runtime directly. However, TorchScript is now in maintenance mode and I would not depend on it anymore.


Thank you very much for your reply! 🙂

Just a couple of other questions:

> Scripted models can be optimized by TorchScript and loaded into the C++ libtorch runtime directly

So, as I mentioned, if I simply load the scripted model back in Python and use it for inference, is there no such optimization?

> However, TorchScript is now in maintenance mode and I would not depend on it anymore.

Could you please elaborate on this? Will TorchScript be discontinued?

Yes, that’s my understanding, as development is focusing on TorchDynamo instead.

TorchScript would also optimize the model in your Python script, but I would generally not recommend relying on it anymore. Instead, use torch.compile, or try the beta torch.export support, to run your model.