Save and load a model

I have trained a model. I want to save it, reload it later, and use it to produce the output for a new image.

I found the function torch.model('path'), but when I reload the model, I always run into problems.

Can anyone give me some suggestions or a simple example?

Thank you so much.


Use torch.save and torch.load:

torch.save(model, 'filename.pt')
model = torch.load('filename.pt')
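A minimal sketch of the two options, using an in-memory buffer instead of a file so it runs anywhere; the small nn.Sequential is a hypothetical stand-in for your own model. Saving only the state_dict is the approach the PyTorch docs recommend. Pickling the whole module (the snippet above) ties the file to the class definition, and on recent releases (2.6 and later), where torch.load defaults to weights_only=True, it also needs weights_only=False, which should only be used for files you trust.

```python
import io

import torch
import torch.nn as nn

# Hypothetical model; stands in for your own nn.Module subclass.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()

# Option 1 (recommended): save only the learned parameters.
buf = io.BytesIO()  # in-memory stand-in for a path such as 'filename.pt'
torch.save(model.state_dict(), buf)
buf.seek(0)
restored = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))  # same architecture
restored.load_state_dict(torch.load(buf))
restored.eval()

# Option 2: pickle the whole module, as in torch.save(model, 'filename.pt').
buf = io.BytesIO()
torch.save(model, buf)
buf.seek(0)
# weights_only=False is required for whole modules on newer releases;
# it runs arbitrary pickle code, so use it on trusted files only.
whole = torch.load(buf, weights_only=False)
whole.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    matches_state_dict = torch.equal(model(x), restored(x))
    matches_whole = torch.equal(model(x), whole(x))
print(matches_state_dict, matches_whole)  # True True
```

With option 1, the loading script only needs the class (or architecture) definition and the file; nothing about the Python module layout at save time is baked into the file.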

@colesbury
I fine-tuned a ResNet-50 model and load the fine-tuned model as follows:

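import torch
from PIL import Image
from torchvision import transforms

model = torch.load('./model_resnet50.pth.tar')

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
    transforms.Resize(256),  # called transforms.Scale in older torchvision releases
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

img = Image.open(IMG_URL).convert('RGB')  # IMG_URL: path to a local image file
img_tensor = preprocess(img)
img_tensor.unsqueeze_(0)  # add a batch dimension: [1, 3, 224, 224]
output = model(img_tensor)  # wrapping in Variable(...) is no longer needed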

However, running the code above raises an error:
TypeError: 'dict' object is not callable

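The content of 'model_resnet50.pth.tar' is a dict (a training checkpoint), not a model, so it cannot be called. The weights are stored under its 'state_dict' key: create the model first, then load them with model.load_state_dict(torch.load('./model_resnet50.pth.tar')['state_dict']).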
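A minimal end-to-end sketch of that fix, under stated assumptions: build_model() is a hypothetical stand-in for the fine-tuned ResNet-50 (in practice it must be the exact architecture that was trained, including any replaced final layer), the checkpoint is written to an in-memory buffer rather than 'model_resnet50.pth.tar', and the "epoch" key is only an example of the extra entries training scripts often add. A random tensor replaces preprocess(img).

```python
import io

import torch
import torch.nn as nn

def build_model() -> nn.Module:
    # Hypothetical stand-in for the fine-tuned network.
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

# What a training script might have saved (extra keys are a convention).
trained = build_model()
buf = io.BytesIO()
torch.save({"epoch": 90, "state_dict": trained.state_dict()}, buf)
buf.seek(0)

checkpoint = torch.load(buf)
print(type(checkpoint).__name__)  # dict: calling this is what raised the TypeError

model = build_model()                            # 1. rebuild the network
model.load_state_dict(checkpoint["state_dict"])  # 2. copy the weights in
model.eval()                                     # 3. inference mode (dropout, batch norm)

img_tensor = torch.rand(1, 3, 32, 32)  # stands in for the preprocessed, batched image
with torch.no_grad():                  # no autograd graph needed for inference
    output = model(img_tensor)
predicted_class = int(output.argmax(dim=1))
print(output.shape, predicted_class)
```

Calling model.eval() matters for networks like ResNet-50, whose batch-norm layers otherwise keep using per-batch statistics at inference time.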

@colesbury OK, thank you so much.

This covers all the .pth files I've encountered or created:

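checkpoint = torch.load(path)  # path: placeholder for your .pth file
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    # training checkpoint that stores the weights under 'state_dict'
    model.load_state_dict(checkpoint['state_dict'])
else:
    # bare state_dict, e.g. saved with torch.save(model.state_dict(), path)
    model.load_state_dict(checkpoint)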
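The same idea can be wrapped in a small helper. This is a hypothetical load_weights function, not a PyTorch API; it also accepts a fully pickled module by taking that module's state_dict. It assumes the files are trusted, because reading a pickled module requires weights_only=False.

```python
import io

import torch
import torch.nn as nn

def load_weights(model: nn.Module, f) -> nn.Module:
    """Hypothetical helper: fill `model` from a training checkpoint dict,
    a bare state_dict, or a fully pickled module."""
    # weights_only=False is needed for the pickled-module case; trusted files only.
    checkpoint = torch.load(f, map_location="cpu", weights_only=False)
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(checkpoint)
    return model

# Round-trip each format through an in-memory buffer and compare weights.
source = nn.Linear(3, 2)
formats = {
    "bare state_dict": source.state_dict(),
    "checkpoint dict": {"state_dict": source.state_dict(), "epoch": 5},
    "whole module": source,
}
results = {}
for name, payload in formats.items():
    buf = io.BytesIO()
    torch.save(payload, buf)
    buf.seek(0)
    target = load_weights(nn.Linear(3, 2), buf)
    results[name] = torch.equal(target.weight, source.weight)
print(results)
```

map_location="cpu" lets files saved on a GPU machine load on a CPU-only one; move the model to the desired device afterwards.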

Hi, I'm trying to load my model using model = torch.load('cnn4mnist.pt'), but I get the error below:

Traceback (most recent call last):
  File "predict.py", line 11, in <module>
    model = torch.load('cnn4mnist.pt')
  File "C:\Users\kanch\Anaconda3\lib\site-packages\torch\serialization.py", line 231, in load
    return _load(f, map_location, pickle_module)
  File "C:\Users\kanch\Anaconda3\lib\site-packages\torch\serialization.py", line 379, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'CNN4MNIST' on <module '__main__' from 'predict.py'>

What should I do? I don't know what I did wrong.

What is the difference between a .pkl file and a .pt file?

The PyTorch serialization format is built off of pickle (.pkl) but overrides some functionality to handle tensors that share the same backing storage.
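The storage-sharing behaviour can be seen directly. A minimal sketch, adapted from the serialization notes in the PyTorch docs: a view saved in the same torch.save call as its base tensor is still a view of it after loading. An in-memory buffer stands in for a file.

```python
import io

import torch

numbers = torch.arange(1, 10)
evens = numbers[1::2]  # a view that shares storage with `numbers`

buf = io.BytesIO()
torch.save([numbers, evens], buf)  # both saved in one call
buf.seek(0)
loaded_numbers, loaded_evens = torch.load(buf)

loaded_evens *= 2      # modify the loaded view in place...
print(loaded_numbers)  # ...and the change appears in the loaded base tensor
```

Plain pickle would serialize the two tensors independently, so the loaded copies would no longer share memory.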

I thought the extension was .pth, not .pt.

It's arbitrary, actually: you can use any extension.


Was the problem solved? I've run into the same problem. Thank you.

Load the state_dict via:

model = MyModel()
state_dict = torch.load("last_brain1.pth")['state_dict']
model.load_state_dict(state_dict)

I am stuck here, at line 154.

This solution does not work.

model is still a dict instead of an nn.Module instance.
Could you print(model) after running my lines of code?

see

model should be initialized as an instance of your model class, e.g.:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # your layer definitions here

    def forward(self, x):
        # your forward pass here
        return x

model = MyModel()
state_dict = torch.load("last_brain1.pth")['state_dict']
model.load_state_dict(state_dict)

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64 * 14 * 14, 30)
        self.fc2 = nn.Linear(30, 3)
        # self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # output size for a 64x64 input: [batch_size, 32, 31, 31]
        x = self.pool(F.relu(self.conv2(x)))  # output size: [batch_size, 64, 14, 14]

        x = x.view(-1, 64 * 14 * 14)  # output size: [batch_size, 64 * 14 * 14]
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # x = self.softmax(x)
        return x

x = torch.randn(1, 3, 64, 64)  # (batch size / number of images, RGB channels, height, width)
model = Net()
output = model(x)
model = Net()
state_dict = torch.load("last_brain1.pth")['state_dict']
model.load_state_dict(state_dict)
print(model)

Here is my code, but it still doesn't work.

What kind of error do you get, and which line of code raises it?
Could you also show how you've created last_brain1.pth?
Maybe 'state_dict' refers to something other than model.state_dict()?
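One way to answer that question is to inspect what the file actually contains before calling load_state_dict. Below is a minimal sketch with a hypothetical describe_checkpoint helper (not a PyTorch API); an in-memory checkpoint with made-up contents stands in for "last_brain1.pth".

```python
import io

import torch
import torch.nn as nn

def describe_checkpoint(obj, max_items=5) -> str:
    """Summarize whatever torch.load returned (hypothetical debugging helper)."""
    if isinstance(obj, nn.Module):
        return f"whole module: {type(obj).__name__}"
    if isinstance(obj, dict):
        parts = []
        for key, value in list(obj.items())[:max_items]:
            if torch.is_tensor(value):
                parts.append(f"{key}: tensor{tuple(value.shape)}")
            else:
                parts.append(f"{key}: {type(value).__name__}")
        return "dict -> " + ", ".join(parts)
    return type(obj).__name__

# Hypothetical file contents standing in for "last_brain1.pth".
net = nn.Linear(4, 2)
buf = io.BytesIO()
torch.save({"state_dict": net.state_dict(), "optimizer": {}}, buf)
buf.seek(0)
checkpoint = torch.load(buf)

print(describe_checkpoint(checkpoint))
print(describe_checkpoint(checkpoint["state_dict"]))
```

If the inner keys do not match the parameter names printed by model.state_dict().keys(), load_state_dict raises a RuntimeError listing the missing and unexpected keys. Keys that start with "module." usually mean the weights were saved from an nn.DataParallel wrapper, so the prefix must be stripped (or the model wrapped the same way) before loading.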