I don’t think you can avoid this issue.
When you are saving and loading the `state_dict`, you would need to create the model first, that's correct. The model definition might have changed and you might get mismatches in the `state_dict`.
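A minimal sketch of that workflow, assuming two hypothetical classes `OldModel` and `NewModel` that mirror the `MyModel` example further down before and after its linear layer is changed. With the default `strict=True`, `load_state_dict` raises a `RuntimeError` as soon as a parameter shape no longer matches, so the mismatch surfaces at load time instead of later:

```python
import io

import torch
import torch.nn as nn


class OldModel(nn.Module):
    # Hypothetical definition at save time (forward omitted, it is not needed for loading)
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 3, 1, 1)
        self.fc = nn.Linear(3 * 24 * 24, 1)


class NewModel(nn.Module):
    # Hypothetical edited definition: fc now has 10 output features
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 3, 1, 1)
        self.fc = nn.Linear(3 * 24 * 24, 10)


# Save only the parameters (an in-memory buffer replaces a checkpoint file here)
buffer = io.BytesIO()
torch.save(OldModel().state_dict(), buffer)
buffer.seek(0)
state_dict = torch.load(buffer)  # a plain dict of tensors, so the weights_only default is fine

# The model has to be created first, then the parameters are copied into it
old_net = OldModel()
old_net.load_state_dict(state_dict)  # succeeds: same definition

new_net = NewModel()
try:
    new_net.load_state_dict(state_dict)  # strict=True by default
    load_error = None
except RuntimeError as err:
    load_error = str(err)

print(load_error)  # reports a size mismatch for fc.weight and fc.bias
```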
However, in your current approach the same mismatch might happen and might be hidden inside the model returned by `torch.load`.
E.g. this simple use case demonstrates it:
```python
# model.py
import torch
import torch.nn as nn
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, 3, 1, 1)  # padding=1 keeps the 24x24 spatial size
        self.fc = nn.Linear(3*24*24, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)  # flatten to [batch_size, 3*24*24]
        x = self.fc(x)
        return x
```
```python
# save.py
import torch
from model import MyModel

net = MyModel()
x = torch.randn(1, 1, 24, 24)
out = net(x)
torch.save(net, 'tmp.pt')  # pickles the whole module object, not only its state_dict
```
```python
# load.py
import torch
net = torch.load('tmp.pt', weights_only=False)  # needed on PyTorch >= 2.6, where weights_only defaults to True
x = torch.randn(1, 1, 24, 24)
out = net(x)
print(out.shape)
```
After changing the linear layer in `model.py` to `nn.Linear(3*24*24, 10)` and executing `load.py`, I get:

```
/opt/conda/lib/python3.6/site-packages/torch/serialization.py:644: SourceChangeWarning: source code of class 'model.MyModel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
```
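This already sounds hard to debug, but at least the code still runs. Note, however, that the output still has the shape `[1, 1]`, which corresponds to the initial definition, not the new one: the pickled model still contains the original `fc` layer, since `torch.load` restores the saved attributes instead of re-running `__init__`. This might be even harder to debug.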
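The underlying mechanism is plain pickling: the class is stored only by reference (module and class name), while the instance's attributes, such as the already-built `fc` layer, are stored by value. On loading, the current class is imported, but `__init__` is not called. A minimal pure-Python sketch of this, assuming a hypothetical in-memory module `model_demo` standing in for `model.py` and a toy `Layer` class standing in for `nn.Linear`:

```python
import pickle
import sys
import types

# Source of a hypothetical "model.py" before the edit; Layer stands in for nn.Linear
OLD_SOURCE = """
class Layer:
    def __init__(self, out_features):
        self.out_features = out_features


class MyModel:
    def __init__(self):
        self.fc = Layer(1)

    def output_size(self):
        return self.fc.out_features
"""
# The "edited" file: the layer now has 10 output features
NEW_SOURCE = OLD_SOURCE.replace("Layer(1)", "Layer(10)")

# Register an in-memory module so pickle can look the classes up by name
model_demo = types.ModuleType("model_demo")
sys.modules["model_demo"] = model_demo
exec(OLD_SOURCE, model_demo.__dict__)

# Pickle an instance created from the old definition
blob = pickle.dumps(model_demo.MyModel())

# "Edit model.py": redefine the classes in the same module
exec(NEW_SOURCE, model_demo.__dict__)

restored = pickle.loads(blob)
print(type(restored) is model_demo.MyModel)  # True: the current class is used
print(restored.output_size())                # 1: the pickled fc is restored, __init__ is not re-run
print(model_demo.MyModel().output_size())    # 10: a fresh instance uses the new definition
```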
After changing the name of `self.fc` to `self.fc_new` and executing `load.py`, I get:

```
/opt/conda/lib/python3.6/site-packages/torch/serialization.py:644: SourceChangeWarning: source code of class 'model.MyModel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
Traceback (most recent call last):
  File "load.py", line 5, in <module>
    out = net(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "/workspace/src/model.py", line 13, in forward
    x = self.fc_new(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 621, in __getattr__
    type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'MyModel' object has no attribute 'fc_new'
```
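For comparison, with the `state_dict` approach the same rename is reported when the parameters are loaded, not only once `forward` reaches the missing attribute. Passing `strict=False` to `load_state_dict` does not raise and instead returns the mismatching keys, so they can be checked explicitly. A small sketch, using hypothetical `SavedModel` and `RenamedModel` classes that mimic the rename above:

```python
import torch.nn as nn


class SavedModel(nn.Module):
    # Hypothetical definition at save time: the layer is called "fc"
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 * 24 * 24, 1)


class RenamedModel(nn.Module):
    # Hypothetical edited definition: the layer was renamed to "fc_new"
    def __init__(self):
        super().__init__()
        self.fc_new = nn.Linear(3 * 24 * 24, 1)


state_dict = SavedModel().state_dict()

# Default strict=True: the rename is reported immediately when loading
try:
    RenamedModel().load_state_dict(state_dict)
    strict_error = None
except RuntimeError as err:
    strict_error = str(err)
print(strict_error)  # lists the missing fc_new.* keys and the unexpected fc.* keys

# strict=False: loading does not raise, and the mismatching keys are returned
result = RenamedModel().load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # keys the model expects but the state_dict lacks
print(result.unexpected_keys)  # keys in the state_dict that the model does not use
```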
As already said, use it at your own risk, but I would strongly advise against it, as I think debugging these issues in a “real” model might take a lot of time.