When storing a trained model via
torch.save(model, path), the module definition is not stored within the saved file.
This leads to problems when I want to load the pre-trained model from other projects (or with global command-line tools), e.g. leading to:
ModuleNotFoundError: No module named 'ResNet' when loading with torch.load(path).
Can I store the model definition (i.e. ResNet.py) within the model itself? If yes, how?
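A minimal reproduction of the problem (the module and class names here are illustrative stand-ins for ResNet.py, since the actual code isn't shown):

```python
import os
import sys
import tempfile
import textwrap

import torch

# Create a throwaway module "mynet" standing in for ResNet.py
# ("mynet"/"TinyNet" are illustrative names, not from the real project).
tmpdir = tempfile.mkdtemp()
with open(os.path.join(tmpdir, "mynet.py"), "w") as f:
    f.write(textwrap.dedent("""\
        import torch.nn as nn

        class TinyNet(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(4, 2)
    """))

sys.path.insert(0, tmpdir)
import mynet

path = os.path.join(tmpdir, "model.pt")
# This pickles a *reference* to mynet.TinyNet, not its source code.
torch.save(mynet.TinyNet(), path)

# Simulate a second project where mynet.py does not exist.
sys.path.remove(tmpdir)
del sys.modules["mynet"]

err = None
try:
    torch.load(path, weights_only=False)  # weights_only requires torch >= 1.13
except ModuleNotFoundError as e:
    err = e
print(err)  # No module named 'mynet'
```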
The documentation just says:
However, in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.
So no, unfortunately, this doesn’t help: that is exactly the behaviour I described and want to avoid.
From the documentation you can see that the PyTorch developers have given serialization quite a bit of thought and concluded that, as you discovered yourself, pickling whole models is too fragile. That is why they adopted the approach of providing factory functions, e.g. in torchvision.
So what you would need to do to follow that best practice is to put your models in a module that Python can then import, and save only the parameters. This isn’t PyTorch-specific, but a general limitation of Python serialisation: pickle stores a reference to the class, not its source code.
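A sketch of that pattern, with illustrative names (in your case the class would live in ResNet.py, installed or otherwise importable from every project that needs it):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Illustrative architecture; this class lives in an importable module.
class TinyNet(nn.Module):
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.fc = nn.Linear(4, hidden)

    def forward(self, x):
        return self.fc(x)

def tiny_net(hidden: int = 8) -> TinyNet:
    """Factory function in the spirit of torchvision.models.resnet18()."""
    return TinyNet(hidden=hidden)

path = os.path.join(tempfile.mkdtemp(), "tiny.pt")

# Save only the parameters; the class definition stays in the module.
model = tiny_net()
torch.save(model.state_dict(), path)

# Any project that can import the module rebuilds the model the same way:
restored = tiny_net()
restored.load_state_dict(torch.load(path))
```

Because only tensors are stored, loading is immune to refactors as long as the factory function keeps producing a matching architecture.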
In the meantime, you could see whether JIT-compiled (TorchScript) modules work for you. They intentionally store both the parameters and the computation in one file, so you can use the model without further dependencies. The obvious drawback is that, to work around Python’s serialisation limitations, Python is left out of the loop, so the saved model is not quite the same thing as a model loaded from a Python module.
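A sketch of the TorchScript route, again with an illustrative model:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Illustrative module; torch.jit.script() compiles forward() to TorchScript.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

path = os.path.join(tempfile.mkdtemp(), "tiny_scripted.pt")

# The saved archive contains both the parameters and the computation,
# so loading it does not need the TinyNet class at all.
torch.jit.save(torch.jit.script(TinyNet()), path)

loaded = torch.jit.load(path)
out = loaded(torch.randn(1, 4))
print(out.shape)  # torch.Size([1, 2])
```

The loaded object is a ScriptModule, not a TinyNet instance, which is the "not the same as a model loaded from a Python module" caveat in practice.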