I have a question about PyTorch’s loading mechanics when using `torch.save` and `torch.load`. Let’s look at an example.
Suppose I have a network:
```python
# modules/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class ReallySimpleModel(nn.Module):
    def __init__(self, **params):
        super().__init__()
        self.net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(20, 10)),
            ("bn1", nn.BatchNorm1d(10)),
            ("relu1", nn.ReLU()),
            ("linear2", nn.Linear(10, 5)),
            ("bn2", nn.BatchNorm1d(5)),
            ("relu2", nn.ReLU()),
            ("linear3", nn.Linear(5, 1)),
            ("bn3", nn.BatchNorm1d(1)),
            ("relu3", nn.Sigmoid()),
        ]))

    def forward(self, x):
        # Calling the Sequential runs its layers in order.
        x = self.net(x)
        return x
```
Then I create an instance of it:
```python
from pprint import pprint

from modules.model import ReallySimpleModel

net = ReallySimpleModel()
pprint(net)
```
This prints our net:
```
ReallySimpleModel(
  (net): Sequential(
    (linear1): Linear(in_features=20, out_features=10)
    (bn1): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True)
    (relu1): ReLU()
    (linear2): Linear(in_features=10, out_features=5)
    (bn2): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True)
    (relu2): ReLU()
    (linear3): Linear(in_features=5, out_features=1)
    (bn3): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True)
    (relu3): Sigmoid()
  )
)
```
Now, I save it with:
```python
import torch

# Save both the whole pickled model object and its state_dict.
torch.save(
    dict(model=net, model_state=net.state_dict()),
    "net.checkpoint.pth.tar",
)
```
Okay, great, the model is saved. We can even load it back with:
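```python
from pprint import pprint

import torch

# weights_only=False is needed on PyTorch >= 2.6 to unpickle a full model object.
checkpoint = torch.load("net.checkpoint.pth.tar", weights_only=False)
net = checkpoint["model"]
pprint(net)
```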
and the model structure is correct. It works even without manually importing `ReallySimpleModel`, which is very cool.
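Loading works without an explicit import because pickle records a class by reference, as its module path plus its name, and imports that module by itself when loading; the instance’s attributes are stored, but the code of the class is not. Below is a minimal sketch with plain `pickle` (which `torch.save` uses under the hood for the model object), assuming a hypothetical in-memory module `fake_model` registered in `sys.modules` as a stand-in for `modules/model.py`:

```python
import pickle
import sys
import types

# Hypothetical stand-in for modules/model.py: an in-memory module that is
# registered in sys.modules so pickle can import it by name.
fake_model = types.ModuleType("fake_model")
exec(
    "class ReallySimpleModel:\n"
    "    def __init__(self):\n"
    "        self.weights = [1, 2, 3]\n"
    "    def forward(self, x):\n"
    "        return sum(w * x for w in self.weights)\n",
    fake_model.__dict__,
)
sys.modules["fake_model"] = fake_model

net = fake_model.ReallySimpleModel()
data = pickle.dumps(net)

# The class is recorded by reference: module name + class name...
print(b"fake_model" in data, b"ReallySimpleModel" in data)  # True True
# ...the instance state (its attribute dict) is stored...
print(b"weights" in data)  # True
# ...but the code of its methods is not.
print(b"forward" in data)  # False

# pickle.loads imports "fake_model" on its own and looks the class up there.
restored = pickle.loads(data)
print(type(restored) is fake_model.ReallySimpleModel)  # True
print(restored.forward(2))  # 12
```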
And here the interesting part starts. As we know, when we call `torch.save`, PyTorch uses pickle to serialize the model, and it also saves the source code of the model class. So even if we change our `model.py` to:
```python
# modules/model.py, with the class body removed
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class ReallySimpleModel(nn.Module):
    pass
```
Loading still works, and that’s really great! PyTorch emits a `SourceChangeWarning` saying that the original source code has changed, but it works!
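The same mechanism explains why loading survives that edit: pickle only needs *some* class called `ReallySimpleModel` in that module. It creates an instance without calling `__init__` and restores the saved attributes onto it, so stored state (for an `nn.Module`, things like the `net` submodule) comes back, while methods such as `forward` are looked up on whatever the class currently defines. A minimal plain-`pickle` sketch, again assuming a hypothetical in-memory module `fake_model`:

```python
import pickle
import sys
import types

# Hypothetical in-memory module playing the role of model.py.
fake_model = types.ModuleType("fake_model")
sys.modules["fake_model"] = fake_model

# Version 1 of the class: a constructor that sets state, plus a method.
exec(
    "class ReallySimpleModel:\n"
    "    def __init__(self):\n"
    "        self.layers = ['linear1', 'bn1', 'relu1']\n"
    "    def describe(self):\n"
    "        return ', '.join(self.layers)\n",
    fake_model.__dict__,
)
data = pickle.dumps(fake_model.ReallySimpleModel())

# "Edit model.py": same class name, but the body is now just `pass`.
exec("class ReallySimpleModel:\n    pass\n", fake_model.__dict__)

# Loading still succeeds: pickle builds an instance of the *current* class
# (without calling __init__) and restores the saved attribute dict onto it.
restored = pickle.loads(data)
print(restored.layers)  # ['linear1', 'bn1', 'relu1']
# Methods come from the current class, so describe() no longer exists.
print(hasattr(restored, "describe"))  # False
```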
But if we delete `model.py`, or remove `ReallySimpleModel` from it, everything goes wrong: loading fails with an `ImportError` or an `AttributeError`, respectively.
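Conversely, once the module path or the class name recorded in the file can no longer be resolved, there is nothing for pickle to instantiate, which produces exactly these two errors. A plain-`pickle` sketch, with the same hypothetical in-memory `fake_model` module standing in for `model.py`:

```python
import pickle
import sys
import types

# Hypothetical in-memory module standing in for model.py.
fake_model = types.ModuleType("fake_model")
exec("class ReallySimpleModel:\n    pass\n", fake_model.__dict__)
sys.modules["fake_model"] = fake_model
data = pickle.dumps(fake_model.ReallySimpleModel())


def load_error(payload):
    """Return the type of the exception pickle.loads raises, or None."""
    try:
        pickle.loads(payload)
    except Exception as exc:
        return type(exc)
    return None


# Case 1: model.py still exists, but ReallySimpleModel was removed from it.
del fake_model.ReallySimpleModel
missing_class = load_error(data)
print(missing_class.__name__)  # AttributeError

# Case 2: model.py itself was deleted, so the module cannot be imported.
del sys.modules["fake_model"]
missing_module = load_error(data)
print(missing_module.__name__, issubclass(missing_module, ImportError))  # ModuleNotFoundError True
```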
As you can see, the PyTorch loading process doesn’t seem to need the model’s code; I think it doesn’t need any code at all. But it does need the same project structure, to resolve some kind of “dict-keys” problem.
So my question is: is there any solution to this problem? Maybe we can somehow emulate the project structure automatically? Personally, I cannot understand why the project structure matters at all when we load a previously saved model.
When I load a previous model, I don’t want to use the current source code; I want the previous code to come back.
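With plain pickle, the project structure matters only because it is how the class gets located; the class code itself still has to exist somewhere, because unpickling never rebuilds a class from saved source. If the class has merely moved, one workaround is to subclass `pickle.Unpickler` and override its `find_class` hook so that the old module path stored in the file is redirected to the new one. A minimal sketch, assuming hypothetical module names `old_project_model` and `new_project_model`. I believe `torch.load` can plug in a custom unpickler through its `pickle_module` argument, but check the documentation for your PyTorch version before relying on that.

```python
import io
import pickle
import sys
import types

# Hypothetical "old" location of the class, as recorded when saving.
old_module = types.ModuleType("old_project_model")
exec(
    "class ReallySimpleModel:\n"
    "    def __init__(self):\n"
    "        self.hidden = 10\n",
    old_module.__dict__,
)
sys.modules["old_project_model"] = old_module
data = pickle.dumps(old_module.ReallySimpleModel())

# The project is reorganised: the old module disappears and the class
# now lives in a (hypothetical) module called "new_project_model".
del sys.modules["old_project_model"]
new_module = types.ModuleType("new_project_model")
exec("class ReallySimpleModel:\n    pass\n", new_module.__dict__)
sys.modules["new_project_model"] = new_module

# Old path -> new path mapping (hypothetical names).
RENAMES = {"old_project_model": "new_project_model"}


class RenamingUnpickler(pickle.Unpickler):
    """Unpickler that redirects moved modules before resolving a class."""

    def find_class(self, module, name):
        module = RENAMES.get(module, module)
        return super().find_class(module, name)


restored = RenamingUnpickler(io.BytesIO(data)).load()
print(type(restored).__module__, restored.hidden)  # new_project_model 10
```

Note that this only changes *where* the class is looked up: the instance gets the current `ReallySimpleModel` from the new module, with the old saved attributes restored onto it.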