Loading a PyTorch model without the code

Hello, everyone!

I have a question about PyTorch's loading mechanics when using torch.save and torch.load. Let's look at an example:

Suppose I have a network:

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict


class ReallySimpleModel(nn.Module):
    def __init__(self, **params):
        super().__init__()
        
        self.net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(20, 10)),
            ("bn1", nn.BatchNorm1d(10)),
            ("relu1", nn.ReLU()),
            ("linear2", nn.Linear(10, 5)),
            ("bn2", nn.BatchNorm1d(5)),
            ("relu2", nn.ReLU()),
            ("linear3", nn.Linear(5, 1)),
            ("bn3", nn.BatchNorm1d(1)),
            ("relu3", nn.Sigmoid()),
        ]))
    
    def forward(self, x):
        # Calling the Sequential directly (not .forward) also runs registered hooks
        return self.net(x)

After that, I create an instance of it with:

from pprint import pprint

from modules.model import ReallySimpleModel

net = ReallySimpleModel()
pprint(net)

This prints our net:

ReallySimpleModel(
  (net): Sequential(
    (linear1): Linear(in_features=20, out_features=10)
    (bn1): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True)
    (relu1): ReLU()
    (linear2): Linear(in_features=10, out_features=5)
    (bn2): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True)
    (relu2): ReLU()
    (linear3): Linear(in_features=5, out_features=1)
    (bn3): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True)
    (relu3): Sigmoid()
  )
)

Now, I save it with:

torch.save(
    dict(model=net, model_state=net.state_dict()),
    "net.checkpoint.pth.tar",
)

Okay, great, the model is saved. We can even load it with:

checkpoint = torch.load("net.checkpoint.pth.tar")
net = checkpoint["model"]
pprint(net)

and the model structure is correct. It even works without manually importing ReallySimpleModel, which is very cool.
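For comparison, here is a minimal sketch of the two saving styles side by side, using a plain nn.Sequential as a stand-in for ReallySimpleModel (the build_model helper and file names are illustrative). Note that recent PyTorch releases (2.6 and later) default torch.load to weights_only=True, so loading a fully pickled module as above needs weights_only=False passed explicitly. A state_dict of tensors loads with the safer default, but you need code that rebuilds the architecture.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in architecture; any nn.Module behaves the same way here.
    return nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 1), nn.Sigmoid())


workdir = tempfile.mkdtemp()
state_path = os.path.join(workdir, "net.state.pth")
full_path = os.path.join(workdir, "net.full.pth")

net = build_model()
torch.save(net.state_dict(), state_path)  # tensors only
torch.save(net, full_path)                # pickled module object (classes stored by reference)

# State-dict route: works with weights_only=True, but the architecture must be rebuilt in code.
rebuilt = build_model()
rebuilt.load_state_dict(torch.load(state_path, weights_only=True))

# Full-object route: needs weights_only=False on newer PyTorch, and the classes must be importable.
full = torch.load(full_path, weights_only=False)

x = torch.randn(4, 20)
with torch.no_grad():
    print(torch.equal(rebuilt(x), full(x)))  # True
```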

And the interesting part starts here. As we know, when we call torch.save, PyTorch uses pickle to serialize the model and its source code. So, even if we change our model.py to:

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict


class ReallySimpleModel(nn.Module):
    pass

Loading still works! And it's really great! There is a warning that the original source code has changed, but it works!
But if we delete model.py, or delete ReallySimpleModel from it, everything goes wrong: an ImportError or AttributeError is raised.

As you can see, the PyTorch loading process doesn't need the model's code, and I think it doesn't need any code at all. But it does need the same project structure, to resolve some kind of “dict keys” problem.
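This behaviour comes from pickle itself rather than anything PyTorch-specific: pickle records a class by its module path and class name and imports it again at load time, while the object's attributes (for a model, its submodules and tensors) are what actually get stored. The stdlib-only sketch below uses a hypothetical module, demo_models, built at runtime to stand in for model.py.

```python
import pickle
import sys
import types

# A throwaway module standing in for model.py (hypothetical name: demo_models).
demo = types.ModuleType("demo_models")
exec(
    "class TinyModel:\n"
    "    def __init__(self):\n"
    "        self.weights = [0.5, -1.0]\n",
    demo.__dict__,
)
sys.modules["demo_models"] = demo

blob = pickle.dumps(demo.TinyModel())

# The pickle names the class (module path + class name); it contains no source code.
print(b"demo_models" in blob, b"TinyModel" in blob)  # True True

# Loading works while demo_models.TinyModel can be imported by name.
restored = pickle.loads(blob)
print(restored.weights)  # [0.5, -1.0]

# Delete the class, as in "delete ReallySimpleModel from model.py": loading now fails.
del demo.TinyModel
try:
    pickle.loads(blob)
    load_error = None
except (AttributeError, ImportError) as exc:
    load_error = exc
print(type(load_error).__name__, load_error)
```

The same mechanism explains the `pass` version of the class: loading succeeds because the stored attributes are restored, but the object would take its methods, including forward, from the current (empty) class.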

So, my question: is there any solution to this problem? Maybe we can somehow emulate the project structure automatically? Personally, I cannot understand why any project structure is needed at all when we load a previously saved model.
When I load a previous model, I don't want to use the current source code; I want the previous one to come back 😄.

There are limitations to loading a PyTorch model without its code.

First limitation:
We only save the source code of the class definition. We do not save anything beyond that (such as the sources of packages the class refers to).

For example:

import torch.nn as nn

import foo  # a separate package or project module


class MyModel(nn.Module):
    def forward(self, input):
        return foo.bar(input)

Here the package foo is not saved in the model checkpoint.

Second limitation:
There are limits to how robustly Python constructs can be serialized. For example, the default pickler cannot serialize lambdas. Some helper packages, such as dill, can serialize more Python constructs than the standard pickler, but they still have limitations.
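A quick stdlib illustration of the lambda limitation; clamp_lambda is a hypothetical stand-in for something like an activation function stored on a module. A callable that pickle can look up by name, such as a functools.partial over the builtin max, round-trips fine. If you do want to try dill, torch.save and torch.load accept a pickle_module argument, though dill has its own limits as noted above.

```python
import functools
import pickle

# A lambda has no importable name, so the standard pickler refuses it.
clamp_lambda = lambda x: max(0.0, x)  # noqa: E731

try:
    pickle.dumps(clamp_lambda)
    lambda_pickled = True
except (pickle.PicklingError, AttributeError, TypeError) as exc:
    lambda_pickled = False
    print(type(exc).__name__, exc)

# The same behaviour built from an importable callable (builtins.max) pickles fine.
clamp_partial = functools.partial(max, 0.0)
restored = pickle.loads(pickle.dumps(clamp_partial))
print(restored(-3.0), restored(2.5))  # 0.0 2.5
```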

Given these limitations, there is no robust way to have torch.load work without having the original source files.

Thank you for your answer.

Nevertheless, I have a few more questions about these limitations. I know that fully correct model serialization is a really challenging problem, but what if we use some tricks?

For example, a large variety of models can be described by an nn.Sequential, with the forward pass looking like this:

def forward(self, x):
    # Calling the Sequential directly (not .forward) also runs registered hooks
    return self.net(x)

For these “types” of models, all the needed source code will be saved. Combined with another hack, like saving the project structure, the model should load without many problems, shouldn't it?

In this limited case, it might work, as long as your Sequential does not contain any layers defined in other files.
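One way to check this condition, sketched below under the assumption that plain pickle is a reasonable proxy for what torch.save writes (torch.save also uses pickle for the object structure), is to list the modules a pickled Sequential refers to. Pickling with protocol 2 writes each class or function reference as a single GLOBAL opcode, which pickletools can read directly.

```python
import pickle
import pickletools

import torch.nn as nn

net = nn.Sequential(
    nn.Linear(20, 10), nn.BatchNorm1d(10), nn.ReLU(),
    nn.Linear(10, 1), nn.Sigmoid(),
)

# Protocol 2 writes each class/function reference as a GLOBAL opcode whose arg is "module name".
blob = pickle.dumps(net, protocol=2)
referenced = sorted({
    arg.split(" ")[0]
    for opcode, arg, _ in pickletools.genops(blob)
    if opcode.name == "GLOBAL"
})
for module_name in referenced:
    print(module_name)

# Nothing from the user's own files (e.g. __main__) is referenced, only torch and the stdlib.
print("__main__" in referenced)  # False
```

If a custom layer from your project appeared in the Sequential, its module path would show up in this list, and that module would be needed at load time.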

Thanks for replying.

Loading without the code has limitations, but I think a PyTorch model should be able to store the whole computation graph by itself, just like TensorFlow, MXNet and other frameworks do.

This is really needed for a general-purpose model serving service. We are trying to implement a service that loads PyTorch models from users' model files, without their source files.

I found this topic by googling. Is there any way to load a PyTorch model without including its source code?

If you export to TorchScript, you can load it without the code.
I wish PyTorch also had a runtime (not just a library, but a set of utilities that make it easy to read exported models and run them in a session); then there would really be no need for the source code.
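A minimal sketch of the TorchScript route, reusing the ReallySimpleModel layers from the first post (written with nn.Sequential's positional form for brevity; the file name is illustrative). torch.jit.script compiles forward into TorchScript, and the saved file holds both the code and the weights, so torch.jit.load (or torch::jit::load in libtorch) does not need the Python class. Newer PyTorch releases also offer torch.export, which is not shown here.

```python
import os
import tempfile

import torch
import torch.nn as nn


class ReallySimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(20, 10), nn.BatchNorm1d(10), nn.ReLU(),
            nn.Linear(10, 5), nn.BatchNorm1d(5), nn.ReLU(),
            nn.Linear(5, 1), nn.BatchNorm1d(1), nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = ReallySimpleModel().eval()  # eval(): BatchNorm uses running statistics
scripted = torch.jit.script(model)

path = os.path.join(tempfile.mkdtemp(), "net.scripted.pt")
scripted.save(path)

# In another process (or from C++), only this file and torch/libtorch are needed.
loaded = torch.jit.load(path)
x = torch.randn(4, 20)
with torch.no_grad():
    print(torch.allclose(loaded(x), model(x)))  # True
```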

I dug into the JIT capability, but hit a roadblock.
I got errors like:
RuntimeError: forward() Expected a value of type 'Tensor' for argument 'ctx' but instead found type 'FunctionBackward'.
Inferred 'ctx' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 0

It seems that autograd support in TorchScript is still a work in progress: https://github.com/pytorch/pytorch/pull/22582
I would appreciate it if anyone could help.

We only save the source code of the class definition. We do not save beyond that (like the package sources that the class is referring to).

I guess this has been changed recently?

After renaming my source code folder from src to <package-name> (e.g. abc), I got the following error:

  File "/home/user/.env/python3/lib/python3.5/site-packages/torch/serialization.py", line 613, in _load
    result = unpickler.load()
ImportError: No module named 'src'

Could you please tell me if there is a simple fix for the .pth files that I saved? I would like to keep <package-name> instead of src for the code.

Thank you very much in advance!
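One commonly used workaround, sketched here with plain pickle so it runs standalone, is to register the old module names in sys.modules as aliases for the renamed package before loading. torch.load (with weights_only=False) resolves classes through pickle's normal lookup, so the same aliasing should apply there. The names mypkg and mypkg.model are hypothetical stand-ins for your renamed package, and make_module only simulates files on disk.

```python
import pickle
import sys
import types


def make_module(name: str, source: str) -> types.ModuleType:
    """Create and register a module at runtime (simulates a file on disk)."""
    module = types.ModuleType(name)
    exec(source, module.__dict__)
    sys.modules[name] = module
    return module


MODEL_SOURCE = "class Net:\n    def __init__(self):\n        self.width = 10\n"

# 1) Old layout: the checkpoint is written while the code lives in src.model.
make_module("src", "")
old_model = make_module("src.model", MODEL_SOURCE)
checkpoint = pickle.dumps(old_model.Net())

# 2) After the rename, src is gone and the same code lives in mypkg.model (hypothetical name).
del sys.modules["src.model"], sys.modules["src"]
make_module("mypkg", "")
new_model = make_module("mypkg.model", MODEL_SOURCE)

# 3) Alias the old module names to the new ones before unpickling (or before torch.load).
sys.modules["src"] = sys.modules["mypkg"]
sys.modules["src.model"] = new_model

restored = pickle.loads(checkpoint)
print(type(restored).__module__, restored.width)  # mypkg.model 10
```

After loading once this way, re-saving the state_dict (or the whole model) writes the new module names, so the alias is only needed for the old files.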

Thank you very much. Your finding is interesting and helped solve my problem (some library dependencies when deserializing a model).

So the .pth checkpoint contains the computation graph? I see now. The whole time, I thought you needed the class that defines forward(x) to define the computation graph. My question, then, is:

  1. In some cases I need an output hook on the output of a certain intermediate layer instead of the final layer of the network. How do I create this hook when loading the model in C++? I'm looking at loading the model with libtorch.

Thank you very much.
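On the first point: a regularly pickled .pth model stores parameters and submodule structure, while the forward logic lives in the Python class; a TorchScript file is what carries the compiled forward code. For the intermediate output, one workaround, assuming you can change the model before exporting it, is to have forward return the intermediate activation alongside the final output. The scripted module then returns a tuple, which libtorch exposes as an IValue (for example via toTuple()). WithFeatures and its layer sizes are hypothetical.

```python
import os
import tempfile
from typing import Tuple

import torch
import torch.nn as nn


class WithFeatures(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(20, 10), nn.ReLU())
        self.head = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        features = self.body(x)  # the intermediate output you want to read in C++
        return self.head(features), features


scripted = torch.jit.script(WithFeatures().eval())
path = os.path.join(tempfile.mkdtemp(), "with_features.pt")
scripted.save(path)  # load from C++ with torch::jit::load(path)

with torch.no_grad():
    output, features = scripted(torch.randn(3, 20))
print(output.shape, features.shape)  # torch.Size([3, 1]) torch.Size([3, 10])
```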