How can I load a model with torch.load without the source class the model was created from?

Hi there, in the first file I'm defining the model class as "Classifier", training the model, and then saving it with torch.save(model, 'model.pt', dill). The code in the first script looks like this:

import torch
from torch import nn

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        # a toy example model
        self.model = nn.Sequential(
            nn.Linear(10, 5),
            nn.LeakyReLU(0.2),
            nn.Linear(5, 10))
    
    def forward(self, x):
        return self.model(x)

    def a_use_full_function(self):
        """
        A function I'm using while training the model,
        so I want to save and load it as well, so that I can
        continue training my model in the second script.
        """

model = Classifier()
model.cuda()
# training the model....
# evaluating and testing it....
# and saving as-

import dill
torch.save(model, 'model.pt', dill)

In the second script I'm trying to load the model without any Classifier definition, so the code in the second script looks exactly like this:

import torch
from torch import nn
import dill

model = torch.load('model.pt', 'cuda', dill)

Running it fails with "AttributeError: Can't get attribute 'Classifier' on <module '__main__' …>".
I understand this error: when I copy and paste the Classifier class into this second script, it works, but of course that is no longer loading the model without the model definition.

So my question is: how can I save and load my model so that I don't have to deal with the original source code of the model definition?

One note: when I saved and loaded a model built only from torch.nn.Sequential, it loaded in the second script without any error.
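
For reference, a minimal sketch of that Sequential-only round trip (the file name seq_model.pt is made up for illustration). It works because nn.Sequential and nn.Linear live inside the installed torch package, so pickle can resolve them without the custom source file:

import torch
from torch import nn

# first script: a model built only from classes that ship with torch.nn
seq_model = nn.Sequential(
    nn.Linear(10, 5),
    nn.LeakyReLU(0.2),
    nn.Linear(5, 10))
torch.save(seq_model, 'seq_model.pt')

# second script: no custom class definition is needed, since pickle can
# find nn.Sequential and nn.Linear in the torch.nn namespace
# (newer PyTorch releases may require weights_only=False here)
loaded = torch.load('seq_model.pt')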


You could save the scripted/traced (JIT) model, which can then be loaded in other applications.
As far as I know, there is no way of loading the model directly without recreating the original file and class structure.
Generally, this workflow is also not recommended, and you should store and load the state_dict instead.
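
For reference, a minimal sketch of the recommended state_dict workflow (the file name model_state.pt is just for illustration); note that the loading script still needs the Classifier source importable, which is exactly the trade-off discussed above:

import torch

# in train.py: store only the parameter tensors, not the pickled class
torch.save(model.state_dict(), 'model_state.pt')

# in the loading script: the Classifier definition must be available here
model = Classifier()
model.load_state_dict(torch.load('model_state.pt', map_location='cuda'))
model.eval()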

Thanks, yeah, as you said I'm now able to save the model with
torch.jit.save(torch.jit.trace(model, (x,)), "model.pth")
and load it like
loaded_model = torch.jit.load("model.pth")
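
For completeness, a minimal end-to-end sketch of that TorchScript round trip, continuing from the first script where model is the trained Classifier (the input shape matches the toy model above):

import torch

# trace the trained model with an example input of the right shape
x = torch.randn(1, 10).cuda()
traced = torch.jit.trace(model, (x,))
torch.jit.save(traced, "model.pth")

# in the second script: no Classifier definition needed
loaded_model = torch.jit.load("model.pth", map_location='cuda')
out = loaded_model(x)

One caveat: tracing only records the forward pass, so extra Python methods such as a_use_full_function are not carried along into the saved module.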

Though one trick I came up with so that I don't have to deal with the Classifier definition while loading: define a load_model function inside the Classifier class, and then use a three-script structure like this (a sketch of load_model follows after the list):

  • model.py (defining the Classifier only)
  • train.py (building and training model by importing Classifier from model.py)
  • load.py (loading the model by importing Classifier from model.py and defining it like)
    from model import Classifier
    loaded_model = Classifier.load_model('model.pth')
    

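For reference, a possible sketch of that load_model helper in model.py; the post doesn't show its body, so the classmethod and the state_dict-based round trip here are my assumptions (this variant presumes train.py saved model.state_dict() to 'model.pth' rather than a traced module):

# model.py (hypothetical load_model body)
import torch
from torch import nn

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(10, 5),
            nn.LeakyReLU(0.2),
            nn.Linear(5, 10))

    def forward(self, x):
        return self.model(x)

    @classmethod
    def load_model(cls, path):
        # Classifier is importable here, so a state_dict round trip suffices
        model = cls()
        model.load_state_dict(torch.load(path, map_location='cpu'))
        return model
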
Not bad actually :wink: