Error when I use torch.save() and torch.load()

torch.save(model, 'model.pth')

model = torch.load('model.pth')

The docs say

This approach uses Python pickle module when serializing the model, thus it relies on the actual class definition to be available when loading the model.

What exactly does available mean? Where should it be available? In which directory or path do I need to have it stored?

I’ve provided a little more detail here.

You should provide the code and the weights rather than the model.
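For example (the class name MyModel and the file name are just placeholders here), saving only the weights looks roughly like this:

import torch

# save only the learned parameters (the state_dict), not the pickled module object
torch.save(model.state_dict(), 'weights.pth')

# later, wherever the class code is importable, rebuild the model and load the weights
model = MyModel()
model.load_state_dict(torch.load('weights.pth'))
model.eval()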

The code and weights are pretty standard, just a regular network in a Jupyter notebook. I wrote this model, which converts an image into a vector, on Colab:

import torch
import torch.nn as nn

class ConvEncoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 16, (3, 3), padding=(1, 1))
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d((2, 2))

        self.conv2 = nn.Conv2d(16, 32, (3, 3), padding=(1, 1))
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d((2, 2))

        self.conv3 = nn.Conv2d(32, 64, (3, 3), padding=(1, 1))
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()
        self.maxpool3 = nn.MaxPool2d((2, 2))

    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.maxpool3(x)

        return x

Then I did

model = ConvEncoder()
torch.save(model, 'enc.pth')

But when I download that file to my local system and run

model = torch.load('enc.pth')

I get an error saying the model is not found. However, if I save and load it in the same place, like a Colab notebook, it does work. So I'm trying to figure out what is breaking here.
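For reference, the local script is basically just this:

import torch

# nothing else is imported or defined in this script;
# 'enc.pth' is the file I saved on Colab and downloaded
model = torch.load('enc.pth')   # this is the line that fails locally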

Pickle serialises the instance; it is supposed to store only the minimal information needed to recover the object. PyTorch works through several layers of modules and inheritance, and I simply think the serialised state fails to reconstruct on its own. I cannot tell you why without studying it deeply.

That is why I still suggest providing code + weights: the weights are stored in an OrderedDict and have far fewer dependencies than the nn.Module.
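You can see that directly. The state_dict is just an OrderedDict mapping parameter names to tensors, so a quick sketch like this (using the encoder from above) prints plain keys and shapes:

sd = ConvEncoder().state_dict()
print(type(sd))                       # <class 'collections.OrderedDict'>
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))  # e.g. conv1.weight (16, 3, 3, 3), bn1.running_mean (16,), ...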


Thanks. It’s just that the API I’m building needs PyTorch to save and load the whole thing at once.

I’m building an explanation API that takes in a model and an image and produces a class activation map (CAM).

I want the API to work for any model that is supplied. This works for TF models, since you can save and load them quite simply, but not for PyTorch models, because I need to instantiate the model definition “before” I fire my API up, which is very inconvenient.

Although you could do that, you still need to provide the code so that pickle can reconstruct the object. In the end (in your case) it’s a matter of just instantiating that object generically with *args and **kwargs.
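A minimal sketch of that idea (the function, argument, and file names here are hypothetical, not part of any PyTorch API):

import torch

def load_model(model_cls, weights_path, *args, **kwargs):
    # model_cls is the user-supplied nn.Module subclass; its code must be importable here
    model = model_cls(*args, **kwargs)
    model.load_state_dict(torch.load(weights_path))
    model.eval()
    return model

# e.g. for the encoder from this thread
encoder = load_model(ConvEncoder, 'enc_weights.pth')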
