Hello, everyone!
I have a question about how PyTorch loading works when we use torch.save and torch.load. Let's look at an example.
Suppose I have a network:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class ReallySimpleModel(nn.Module):
    def __init__(self, **params):
        super().__init__()
        self.net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(20, 10)),
            ("bn1", nn.BatchNorm1d(10)),
            ("relu1", nn.ReLU()),
            ("linear2", nn.Linear(10, 5)),
            ("bn2", nn.BatchNorm1d(5)),
            ("relu2", nn.ReLU()),
            ("linear3", nn.Linear(5, 1)),
            ("bn3", nn.BatchNorm1d(1)),
            ("relu3", nn.Sigmoid()),  # named "relu3", but it is a Sigmoid
        ]))
    def forward(self, x):
        # Call the module itself rather than .forward() so that hooks run.
        return self.net(x)
Then I create an instance of it:
from pprint import pprint
from modules.model import ReallySimpleModel
net = ReallySimpleModel()
pprint(net)
This prints our net:
ReallySimpleModel(
  (net): Sequential(
    (linear1): Linear(in_features=20, out_features=10)
    (bn1): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True)
    (relu1): ReLU()
    (linear2): Linear(in_features=10, out_features=5)
    (bn2): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True)
    (relu2): ReLU()
    (linear3): Linear(in_features=5, out_features=1)
    (bn3): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True)
    (relu3): Sigmoid()
  )
)
Now, I save it with:
torch.save(dict(
    model=net,
    model_state=net.state_dict()),
    "net.checkpoint.pth.tar")
Great, the model is saved. We can even load it back with:
checkpoint = torch.load("net.checkpoint.pth.tar")
net = checkpoint["model"]
pprint(net)
and the model structure is correct. It works even without manually importing ReallySimpleModel, which is very cool.
And here is where the interesting part starts. As we know, when we call torch.save, PyTorch uses pickle to serialize the model along with its source code. So, even if we change our model.py to:
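To see what plain pickle (which torch.save builds on) actually records, here is a small pure-Python sketch with no PyTorch involved. `toy_model` and `Point` are hypothetical stand-ins for modules/model.py and ReallySimpleModel. It disassembles the pickled bytes with `pickletools`: as far as pickle itself is concerned, the class appears only as a `GLOBAL` reference (module name plus class name), the instance is stored as its attribute dictionary, and no method code ends up in the stream.

```python
import io
import pickle
import pickletools
import sys
import types

# Hypothetical module standing in for modules/model.py.
toy_model = types.ModuleType("toy_model")
sys.modules["toy_model"] = toy_model


class Point:
    """Hypothetical stand-in for a model class such as ReallySimpleModel."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def norm2(self):
        return self.x ** 2 + self.y ** 2


# Pretend Point was defined inside toy_model, as a model class would be.
Point.__module__ = "toy_model"
toy_model.Point = Point

# Protocol 2 writes the class reference as a single GLOBAL opcode.
data = pickle.dumps(Point(1, 2), protocol=2)

# Disassemble the byte stream into a readable opcode listing.
listing_buf = io.StringIO()
pickletools.dis(data, out=listing_buf)
listing = listing_buf.getvalue()

# Lines that reference a global: only "toy_model Point" is expected.
global_refs = [line for line in listing.splitlines() if "GLOBAL" in line]
print(global_refs)

# The method name never appears in the pickled bytes: code is not stored.
print(b"norm2" in data)  # False
```

The practical consequence is that loading has to find `toy_model.Point` again by name, in whatever code is importable at load time.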
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class ReallySimpleModel(nn.Module):
    pass
Loading still works! And that's really great! There is a warning that the original source code has changed, but it works!
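A pure-Python sketch of why the `pass` version can still load, assuming torch.load follows the same pickle mechanics for the model object (`toy_model` is a hypothetical module name). Unpickling looks up the class by name in the *current* code, creates an instance without calling `__init__`, and restores the saved attribute dictionary. For an `nn.Module`, the saved attributes presumably include the submodule dictionary, which would explain why the structure still prints correctly, while methods such as `forward` come from whatever class is importable now (here, the empty one).

```python
import pickle
import sys
import types

# Hypothetical module standing in for modules/model.py.
toy_model = types.ModuleType("toy_model")
sys.modules["toy_model"] = toy_model


class ReallySimpleModel:
    """Original version: has data attributes and a method."""

    def __init__(self):
        self.layers = ["linear1", "bn1", "relu1"]

    def describe(self):
        return "original"


ReallySimpleModel.__module__ = "toy_model"
toy_model.ReallySimpleModel = ReallySimpleModel
data = pickle.dumps(ReallySimpleModel())


# "Edit model.py": the class body is replaced with `pass`.
class ReallySimpleModel:  # noqa: F811
    pass


ReallySimpleModel.__module__ = "toy_model"
toy_model.ReallySimpleModel = ReallySimpleModel

# Unpickling looks up toy_model.ReallySimpleModel as it exists *now*,
# creates an instance without calling __init__, and restores the saved __dict__.
restored = pickle.loads(data)
print(type(restored) is toy_model.ReallySimpleModel)  # True
print(restored.layers)  # ['linear1', 'bn1', 'relu1']
print(hasattr(restored, "describe"))  # False
```

The data comes back, but the behavior belongs to the new, empty class: the old method is gone.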
But if we delete model.py, or delete ReallySimpleModel from it, everything goes wrong: ImportError or AttributeError is raised.
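The two failure modes can be reproduced with plain pickle. This is a sketch under the same assumptions as above (`toy_model` is hypothetical, and no real `toy_model.py` is importable from `sys.path`): a missing class inside an existing module gives `AttributeError`, and a missing module gives `ModuleNotFoundError`, a subclass of `ImportError`.

```python
import pickle
import sys
import types

# Hypothetical module standing in for modules/model.py.
toy_model = types.ModuleType("toy_model")
sys.modules["toy_model"] = toy_model


class ReallySimpleModel:
    def __init__(self):
        self.in_features = 20


ReallySimpleModel.__module__ = "toy_model"
toy_model.ReallySimpleModel = ReallySimpleModel
data = pickle.dumps(ReallySimpleModel())

# Case 1: the module still exists, but the class was deleted from it.
del toy_model.ReallySimpleModel
try:
    pickle.loads(data)
except AttributeError as exc:
    attr_error = exc
    print("AttributeError:", exc)

# Case 2: the module itself is gone (model.py deleted).
# Assumes no real toy_model.py can be found on sys.path.
del sys.modules["toy_model"]
try:
    pickle.loads(data)
except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
    import_error = exc
    print(type(exc).__name__ + ":", exc)
```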
As you can see, the PyTorch loading process doesn't need the model's code; I think it doesn't need any code at all. But it does need the same project structure to resolve some kind of "dict-keys" problem.
So, my question: is there any solution to this problem? Maybe we can somehow recreate the project structure automatically? Personally, I cannot understand why the current project structure matters at all when we load a previous model.
When I load a previous model, I don't want to use the current source code; I want the previous one to come back.
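One plain-Python way to loosen the "same project structure" requirement, sketched as an assumption rather than a PyTorch-specific solution: subclass `pickle.Unpickler` and override `find_class`, the hook pickle uses to turn a stored module name and class name into an actual class. `RENAMES`, `old_model` and `new_model` are hypothetical. This still needs *some* version of the class to be importable; it only changes where pickle looks for it, so it does not bring back the old source code. torch.load does accept a `pickle_module` argument, which might allow plugging in a module-like object exposing such an `Unpickler`, but that is not verified here.

```python
import io
import pickle
import sys
import types

# Old layout (hypothetical): the class lived in module "old_model".
old_model = types.ModuleType("old_model")
sys.modules["old_model"] = old_model


class ReallySimpleModel:
    def __init__(self):
        self.hidden = 10


ReallySimpleModel.__module__ = "old_model"
old_model.ReallySimpleModel = ReallySimpleModel
data = pickle.dumps(ReallySimpleModel())

# New layout: "old_model" is gone and the class now lives in "new_model".
del sys.modules["old_model"]
new_model = types.ModuleType("new_model")
sys.modules["new_model"] = new_model
ReallySimpleModel.__module__ = "new_model"
new_model.ReallySimpleModel = ReallySimpleModel

# Hypothetical rename table: (old module, old name) -> (new module, new name).
RENAMES = {
    ("old_model", "ReallySimpleModel"): ("new_model", "ReallySimpleModel"),
}


class RemappingUnpickler(pickle.Unpickler):
    """Unpickler that redirects moved classes before looking them up."""

    def find_class(self, module, name):
        module, name = RENAMES.get((module, name), (module, name))
        return super().find_class(module, name)


restored = RemappingUnpickler(io.BytesIO(data)).load()
print(type(restored).__module__, restored.hidden)  # new_model 10
```

A simpler variant of the same idea is aliasing the old import path before loading, e.g. `sys.modules["old_model"] = new_model`, so that the stored module name resolves to the new module.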