Loading a PyTorch model without its code


(Sergey Kolesnikov) #1

Hello, everyone!

I have a question about how PyTorch loading works when we use torch.save and torch.load. Let’s look at an example:

Suppose I have a network:

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict


class ReallySimpleModel(nn.Module):
    def __init__(self, **params):
        super().__init__()
        
        self.net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(20, 10)),
            ("bn1", nn.BatchNorm1d(10)),
            ("relu1", nn.ReLU()),
            ("linear2", nn.Linear(10, 5)),
            ("bn2", nn.BatchNorm1d(5)),
            ("relu2", nn.ReLU()),
            ("linear3", nn.Linear(5, 1)),
            ("bn3", nn.BatchNorm1d(1)),
            ("relu3", nn.Sigmoid()),
        ]))
    
    def forward(self, x):
        x = self.net.forward(x)
        return x

After that, I create an instance of it with:

from pprint import pprint

from modules.model import ReallySimpleModel

net = ReallySimpleModel()
pprint(net)

This prints our net:

ReallySimpleModel(
  (net): Sequential(
    (linear1): Linear(in_features=20, out_features=10)
    (bn1): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True)
    (relu1): ReLU()
    (linear2): Linear(in_features=10, out_features=5)
    (bn2): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True)
    (relu2): ReLU()
    (linear3): Linear(in_features=5, out_features=1)
    (bn3): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True)
    (relu3): Sigmoid()
  )
)

Now, I save it with:

torch.save(
    dict(model=net, model_state=net.state_dict()),
    "net.checkpoint.pth.tar",
)

Okay, great, the model is saved. We can even load it with:

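import torch
from pprint import pprint
# On PyTorch 2.6+, weights_only=False is required to unpickle a full model object.
checkpoint = torch.load("net.checkpoint.pth.tar", weights_only=False)
net = checkpoint["model"]
pprint(net)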

and the model structure is correct. It works even without manually importing ReallySimpleModel, which is very cool.

And the interesting part starts here. As we know, when we call torch.save, PyTorch uses pickle to serialize the model and its source code. So even if we change our model.py to:

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict


class ReallySimpleModel(nn.Module):
    pass

Loading still works! And that’s really great! There is a warning that the original source code has changed, but it works!
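As far as we can tell from torch.serialization, the class source stored in the checkpoint is only compared with the current source to decide whether to emit a SourceChangeWarning; it is never executed. The toy sketch below imitates that idea with plain pickle. The function names, ToySourceChangeWarning and the source strings are all made up, and collections.Counter stands in for the model class.

```python
import collections
import difflib
import pickle
import warnings


class ToySourceChangeWarning(UserWarning):
    """Hypothetical warning class for this sketch (not PyTorch's own)."""


def save_with_source(obj, class_source: str) -> bytes:
    """Pickle obj (its class is stored by reference) plus a text copy of the class source."""
    return pickle.dumps({"source": class_source, "payload": pickle.dumps(obj, protocol=2)})


def load_with_source_check(blob: bytes, current_source: str):
    """Warn if the class source changed, then unpickle using whatever class exists now."""
    record = pickle.loads(blob)
    if record["source"] != current_source:
        diff = "\n".join(
            difflib.unified_diff(
                record["source"].splitlines(),
                current_source.splitlines(),
                "saved",
                "current",
                lineterm="",
            )
        )
        warnings.warn(f"class source has changed:\n{diff}", ToySourceChangeWarning)
    # The saved source is only compared, never executed.
    return pickle.loads(record["payload"])


# Made-up "before" and "after" versions of the class source.
old_src = "class ReallySimpleModel(nn.Module):\n    def forward(self, x): ...\n"
new_src = "class ReallySimpleModel(nn.Module):\n    pass\n"
blob = save_with_source(collections.Counter("aab"), old_src)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    restored = load_with_source_check(blob, new_src)

print(restored)                     # Counter({'a': 2, 'b': 1})
print(caught[0].category.__name__)  # ToySourceChangeWarning
```

The object loads even though the "source" differs, which matches the behaviour described above: a warning, but a successful load.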
But if we delete model.py, or delete ReallySimpleModel from it, everything goes wrong: ImportErrors or AttributeErrors appear.

As you can see, PyTorch’s loading process doesn’t need the model’s code; I think it doesn’t need any code at all. But it does need the same project structure to resolve some kind of “dict-keys” problem.
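The “dict-keys” behaviour can be reproduced with plain pickle, without PyTorch. A minimal sketch, assuming the model is pickled like any other object: the pickle stores the class only as a reference (module name plus class name) together with the instance’s attributes, so loading needs something importable at that path but not the original code. The module name fake_model is hypothetical, and an in-memory module stands in for model.py.

```python
import pickle
import sys
import types


def make_module(name: str, source: str) -> types.ModuleType:
    """Create an in-memory module and register it, standing in for a file like model.py."""
    module = types.ModuleType(name)
    exec(source, module.__dict__)
    sys.modules[name] = module
    return module


def try_load(data: bytes):
    """Return (object, None) on success, or (None, exception) if unpickling fails."""
    try:
        return pickle.loads(data), None
    except (AttributeError, ImportError, pickle.UnpicklingError) as exc:
        return None, exc


# "fake_model" is a hypothetical module name used only for this demo.
mod = make_module(
    "fake_model",
    "class ReallySimpleModel:\n"
    "    def __init__(self):\n"
    "        self.weights = [0.1, 0.2, 0.3]\n",
)
# The pickle records the class as ("fake_model", "ReallySimpleModel") plus the
# instance's attributes; the class body itself is not stored.
payload = pickle.dumps(mod.ReallySimpleModel(), protocol=2)

# 1) Replace the class with an empty stub: loading still works and restores the attributes.
exec("class ReallySimpleModel:\n    pass\n", mod.__dict__)
stub_obj, _ = try_load(payload)
print(stub_obj.weights)  # [0.1, 0.2, 0.3]

# 2) Remove the class but keep the module: the name lookup fails (AttributeError on CPython).
del mod.ReallySimpleModel
_, class_error = try_load(payload)
print(repr(class_error))

# 3) Remove the module itself: the import fails (ModuleNotFoundError, a subclass of ImportError).
del sys.modules["fake_model"]
_, module_error = try_load(payload)
print(repr(module_error))
```

So the “keys” are module paths and class names: whatever object sits at that path when you load is what gets used.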

So, my question: is there any solution to this problem? Maybe we could somehow emulate the project structure automatically? Personally, I cannot understand why the project structure matters at all when we load a previous model.
When I load a previous model, I don’t want to use the current source code; I want the previous code to come back :smile:.
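One way to “emulate” the old project structure is to keep a copy of the old code under a new module name and redirect the lookups pickle performs. The sketch below overrides pickle.Unpickler.find_class, the standard-library hook for this; MODULE_ALIASES and both module names are hypothetical. It uses plain pickle only; we have not verified how it interacts with torch.load’s pickle_module argument.

```python
import io
import pickle
import sys
import types

# Hypothetical mapping: module name recorded in old pickles -> where a copy of
# that old code lives now. Both names are made up for this sketch.
MODULE_ALIASES = {"old_model": "archived_model_v1"}


class RemappingUnpickler(pickle.Unpickler):
    """Unpickler that rewrites recorded module names before resolving classes."""

    def find_class(self, module, name):
        return super().find_class(MODULE_ALIASES.get(module, module), name)


def load_with_aliases(data: bytes):
    return RemappingUnpickler(io.BytesIO(data)).load()


def register_module(name: str, source: str) -> types.ModuleType:
    """Create an in-memory module (a stand-in for a .py file) and register it."""
    module = types.ModuleType(name)
    exec(source, module.__dict__)
    sys.modules[name] = module
    return module


CLASS_SOURCE = (
    "class ReallySimpleModel:\n"
    "    def __init__(self):\n"
    "        self.in_features = 20\n"
)

# Save while the class lives in "old_model", then simulate deleting that file.
old = register_module("old_model", CLASS_SOURCE)
payload = pickle.dumps(old.ReallySimpleModel(), protocol=2)
del sys.modules["old_model"]

# A saved copy of the old code now lives under a different module name.
register_module("archived_model_v1", CLASS_SOURCE)
restored = load_with_aliases(payload)
print(type(restored).__module__, restored.in_features)  # archived_model_v1 20
```

This still needs the old code to exist somewhere; it only removes the requirement that it live at the original path.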


#2

There are limitations to loading a PyTorch model without its code.

First limitation:
We only save the source code of the class definition. We do not save anything beyond that (such as the sources of packages that the class refers to).

For example:

import torch.nn as nn

import foo  # some package the model uses; its code is not saved


class MyModel(nn.Module):
    def forward(self, input):
        return foo.bar(input)

Here the package foo is not saved in the model checkpoint.
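To see which module-level names a class depends on (and which would therefore have to be importable at load time), one can inspect the names its methods look up. A rough sketch, assuming plain Python methods; external_globals is a hypothetical helper, and math stands in for foo from the example above.

```python
import math


class MyModel:
    def forward(self, x):
        # "math" plays the role of the package "foo".
        return math.sqrt(x)


def external_globals(cls) -> set[str]:
    """Module-level names that cls's methods look up at call time.

    Only these *names* appear in the class source; the code behind them lives in
    other modules and is not included. Shallow sketch: nested functions,
    staticmethods and classmethods are not inspected.
    """
    names = set()
    for attr in vars(cls).values():
        code = getattr(attr, "__code__", None)
        if code is None:
            continue
        # co_names also contains attribute names like "sqrt"; keep only real globals.
        names.update(n for n in code.co_names if n in attr.__globals__)
    return names


print(external_globals(MyModel))  # {'math'}
```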

Second limitation:
Not every Python construct can be serialized robustly. For example, the default pickler cannot serialize lambdas. Helper packages such as dill can serialize more Python constructs than the standard pickler, but they still have limitations.
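A quick check of the lambda limitation with the standard pickler, as a sketch: functions are pickled by reference (module plus qualified name), so a named, importable function works, while a lambda or a function defined inside another function has no importable name. can_pickle and make_scaler are hypothetical helpers; dill is not used here, to keep the example stdlib-only.

```python
import math
import pickle


def can_pickle(obj) -> bool:
    """True if the standard pickler can serialize obj."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


def make_scaler(factor):
    # A closure: defined inside another function, so it has no importable name.
    def scale(x):
        return factor * x

    return scale


square = lambda x: x * x  # noqa: E731 (a lambda on purpose)

print(can_pickle(math.sqrt))       # True: stored by reference as math.sqrt
print(can_pickle(square))          # False: "<lambda>" is not an importable name
print(can_pickle(make_scaler(2)))  # False: local object
```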

Given these limitations, there is no robust way to have torch.load work without having the original source files.


(Sergey Kolesnikov) #3

Thank you for your answer.

Nevertheless, I have a few more questions about the limitations. I know that fully correct model serialization is a really challenging problem, but what if we use some tricks?

For example, a large variety of models can be described with nn.Sequential, with a forward pass that looks like this:

def forward(self, x):
    x = self.net.forward(x)
    return x

For these “types” of models, all the needed source code would be saved. And with another hack, like saving the project structure, the model should load without many problems, shouldn’t it?


#4

In this limited case, it might work, as long as your Sequential does not contain any layers defined in other files.
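To check whether a given checkpoint depends on project code, one can list the module paths its pickle will import. A sketch under stated assumptions: the checkpoint is in the zip format with its main pickle stored as <dir>/data.pkl, and it was written with torch.save’s default pickle protocol 2, which stores classes as GLOBAL opcodes. required_modules and checkpoint_modules are hypothetical helpers, and the demo builds a fake checkpoint so it runs without PyTorch.

```python
import collections
import io
import pickle
import pickletools
import zipfile


def required_modules(pickle_bytes: bytes) -> set[str]:
    """Module paths a protocol-2 pickle imports on load (from its GLOBAL opcodes).

    Protocol 4+ pickles use STACK_GLOBAL instead, which this sketch does not handle.
    """
    modules = set()
    for opcode, arg, _pos in pickletools.genops(pickle_bytes):
        if opcode.name == "GLOBAL":  # arg looks like "torch.nn.modules.linear Linear"
            modules.add(arg.split(" ", 1)[0])
    return modules


def checkpoint_modules(file) -> set[str]:
    """Scan a zip-format checkpoint, assuming its main pickle is stored as <dir>/data.pkl."""
    with zipfile.ZipFile(file) as archive:
        members = [n for n in archive.namelist() if n.endswith("/data.pkl")]
        if not members:
            raise ValueError("no */data.pkl member found; not a zip-format checkpoint?")
        return required_modules(archive.read(members[0]))


# Self-contained demo: a fake "checkpoint" whose pickle holds only an OrderedDict.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    fake_state = collections.OrderedDict(weight=[1.0, 2.0])
    archive.writestr("net/data.pkl", pickle.dumps(fake_state, protocol=2))

found = checkpoint_modules(buffer)
print(found)  # {'collections'}
```

For a pure nn.Sequential of built-in layers we would expect only torch.* and standard-library entries; a project path such as modules.model showing up is exactly the dependency that breaks loading once that file is gone.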


(tobe) #5

Thanks for replying.

Loading without the code has limitations, but I think a PyTorch model could store the whole computation graph it needs by itself, just like TensorFlow, MXNet and other frameworks do.

This is really needed for a general model-serving service. We are trying to implement a service that loads PyTorch models from users’ model files alone, without their source files.