Load DataParallel model

Hi, I am loading a model wrapped in DataParallel. Below is my code:

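import torch
import torch.nn as nn

# training
model = build_model()
model = nn.DataParallel(model)
model.cuda()
# ... training ...
torch.save(model, name)

# eval
net = torch.load(name, weights_only=False)  # needed on PyTorch >= 2.6 to unpickle a whole model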

I save the whole model in its DataParallel form, not just the state_dict, and I load the model directly for eval. Is this correct? Will it give the same performance as saving and loading the state_dict? Thank you so much!

You should get the same performance.
However, your current approach assumes that all source files contain the same definitions and are in the same location. Also, torch.load would load the nn.DataParallel model and might assume the same number of GPUs in your system (not sure about it and you would have to test it).
Given these disadvantages, I always recommend storing the state_dict, which gives you more flexibility when loading the model later.
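A minimal sketch of the state_dict approach for a DataParallel model, assuming a small hypothetical nn.Sequential in place of the real build_model() and an illustrative file name checkpoint.pt. Saving the inner model.module.state_dict() keeps the keys free of the "module." prefix that the DataParallel wrapper adds, so the weights can be loaded into a plain, unwrapped instance on any machine; map_location="cpu" avoids depending on the GPUs that were present during training.

```python
import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in for the real build_model() from the question.
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))


# Training side: wrap in DataParallel (without GPUs it simply calls the inner module).
model = nn.DataParallel(build_model())
if torch.cuda.is_available():
    model.cuda()
# ... training ...

# The wrapper's own state_dict prefixes every key with "module.".
print(next(iter(model.state_dict())))  # module.0.weight

# Save the inner module's state_dict so the keys carry no prefix.
torch.save(model.module.state_dict(), "checkpoint.pt")

# Eval side: create the model first, then load the weights onto the CPU.
net = build_model()
state_dict = torch.load("checkpoint.pt", map_location="cpu")
net.load_state_dict(state_dict)
net.eval()
```

The same state_dict can later be loaded into a fresh nn.DataParallel(build_model()) via its .module attribute, if multi-GPU evaluation is needed.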

OK, I understand. Thank you!

Hi! I use torch.save(net, ...) to save the model structure and weights, but I changed the model structure later. If I now use torch.load to load the saved .pth file into newnet, will the forward pass of newnet follow the model structure from before or after the change? Thanks!

I would assume that the new model definition will be used, but I never tested it.
Note that I’m not using this approach, as it might break in several ways. For example, you might not be able to load the model again if you change the file structure or the definition too much.

Creating the new model instance directly and just loading the state_dict is the cleaner way.

I understand, but the reason I use torch.save() on the whole model instead of the state_dict is that I want to load the model and its weights directly, without defining the model first. My model structure may change many times. If I only save the state_dict during training, I may not remember the exact structure version when I evaluate the model, and loading a state_dict requires defining the model structure first, so the state_dict of the new model and the loaded one may not match.

I don’t think you can avoid this issue.
When you are saving and loading the state_dict, you would need to create the model first, that’s correct.
The model definition might have changed and you might get mismatches in the state_dict.
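As a sketch of what such a mismatch looks like with the state_dict approach, assuming two hypothetical versions of the model from this thread (OldModel with a 1-output fc layer, NewModel with a 10-output fc layer): load_state_dict uses strict=True by default and fails loudly with a RuntimeError that names the mismatching parameters, instead of silently running with stale weights.

```python
import torch
import torch.nn as nn


class OldModel(nn.Module):
    # Hypothetical "before" version; forward is omitted since only parameters matter here.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 3, 1, 1)
        self.fc = nn.Linear(3 * 24 * 24, 1)


class NewModel(nn.Module):
    # Hypothetical "after" version: fc now has 10 outputs.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 3, 1, 1)
        self.fc = nn.Linear(3 * 24 * 24, 10)


state_dict = OldModel().state_dict()
new_net = NewModel()

error_message = ""
try:
    new_net.load_state_dict(state_dict)  # strict=True by default
except RuntimeError as err:
    error_message = str(err)
    print(error_message)  # reports a size mismatch for fc.weight and fc.bias
```

A renamed layer (e.g. fc to fc_new) would instead be reported as missing and unexpected keys, again at load time rather than somewhere inside forward.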

However, with your current approach the same mismatch can still happen, and it might be hidden behind the torch.load call.
For example, this simple use case demonstrates it:

# model.py
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, 3, 1, 1)
        self.fc = nn.Linear(3*24*24, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# save.py
import torch

from model import MyModel

net = MyModel()
x = torch.randn(1, 1, 24, 24)
out = net(x)

torch.save(net, 'tmp.pt')

# load.py
import torch

net = torch.load('tmp.pt', weights_only=False)  # PyTorch >= 2.6 defaults to weights_only=True
x = torch.randn(1, 1, 24, 24)
out = net(x)
print(out.shape)

After changing the linear layer in model.py to nn.Linear(3*24*24, 10) and executing load.py I get:

/opt/conda/lib/python3.6/site-packages/torch/serialization.py:644: SourceChangeWarning: source code of class 'model.MyModel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.

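This already sounds hard to debug, but at least the code still runs. Note, however, that the output still has the shape [1, 1], which corresponds to the original definition rather than the new one: the unpickled model restores the saved submodules (including the original fc layer) without calling __init__ again, while forward is taken from the current source. This might be even harder to debug.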
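To see why the output keeps its original shape, here is a single-file sketch of the same experiment. It makes two assumptions not in the original: the model is pickled into an in-memory io.BytesIO buffer instead of tmp.pt, and editing model.py is simulated by replacing MyModel.__init__ after saving. It must be run as a script, since pickle looks the class up by name, and it passes weights_only=False, which PyTorch >= 2.6 requires to unpickle a whole model.

```python
import io

import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 3, 1, 1)
        self.fc = nn.Linear(3 * 24 * 24, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


# Save the whole model object (pickled) into a buffer instead of 'tmp.pt'.
buffer = io.BytesIO()
torch.save(MyModel(), buffer)


def new_init(self):
    # Simulated edit of model.py: fc now has 10 outputs.
    nn.Module.__init__(self)
    self.conv1 = nn.Conv2d(1, 3, 3, 1, 1)
    self.fc = nn.Linear(3 * 24 * 24, 10)


MyModel.__init__ = new_init

# Unpickling restores the saved submodules without calling __init__.
buffer.seek(0)
loaded = torch.load(buffer, weights_only=False)
out = loaded(torch.randn(1, 1, 24, 24))
print(out.shape)  # torch.Size([1, 1])

# A freshly constructed model does use the new definition.
fresh_out = MyModel()(torch.randn(1, 1, 24, 24))
print(fresh_out.shape)  # torch.Size([1, 10])
```

The loaded object mixes the old parameters with whatever forward the current class defines, which is exactly why a renamed attribute only fails once forward touches it.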

After renaming self.fc to self.fc_new (in both __init__ and forward) and executing load.py, I get:

/opt/conda/lib/python3.6/site-packages/torch/serialization.py:644: SourceChangeWarning: source code of class 'model.MyModel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
Traceback (most recent call last):
  File "load.py", line 5, in <module>
    out = net(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "/workspace/src/model.py", line 13, in forward
    x = self.fc_new(x)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 621, in __getattr__
    type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'MyModel' object has no attribute 'fc_new'

As already said, use it at your own risk, but I would strongly advise against it, as debugging these issues in a “real” model might take a lot of time.