Save and Load model

I have trained a model. I want to save it, then reload it and use it to produce the output for a new image.

I found the function torch.model('path'), but whenever I reload the model I run into problems.

Can anyone give me some suggestions or a simple example?

Thank you so much.

Use torch.save and torch.load:
torch.save(model, '')
model = torch.load('')
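
As a minimal sketch of that round trip, the snippet below saves and reloads a whole module. The built-in nn.Linear layer and the file name "model.pt" are stand-ins chosen for illustration, not part of the original question. It also assumes a PyTorch version in which torch.load accepts the weights_only argument; older releases do not have it, and there the argument would simply be dropped.

```python
import os
import tempfile

import torch
import torch.nn as nn

# A built-in layer stands in for the trained model (illustrative only).
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")  # hypothetical file name

    # Pickle the whole module object, not just its weights.
    torch.save(model, path)

    # Newer PyTorch releases default to weights_only=True, which refuses to
    # unpickle full modules, so it is disabled explicitly here.
    # Only do this for files you trust.
    loaded = torch.load(path, weights_only=False)

loaded.eval()  # switch to inference mode before predicting
x = torch.randn(1, 4)
with torch.no_grad():
    same = torch.equal(model(x), loaded(x))
print(same)
```

Saving the whole module this way ties the file to the class definition, so the class has to be available again when the file is loaded.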

I fine-tuned a ResNet-50 model and load the fine-tuned model as follows:

model = torch.load('./model_resnet50.pth.tar')

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([

img =
img_tensor = preprocess(img)
output = model(Variable(img_tensor))

However, running the above code raises an error:
TypeError: 'dict' object is not callable

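The content of 'model_resnet50.pth.tar' is a dict, so torch.load returns that dict rather than a model. The weights are stored under the 'state_dict' key: state_dict = torch.load('./model_resnet50.pth.tar')['state_dict']. Create the model first, then load the weights into it with model.load_state_dict(state_dict).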
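
A rough sketch of why the dict is not callable and how the weights end up in a real module. TinyNet, the 'epoch' entry and the in-memory buffer are hypothetical stand-ins for the fine-tuned ResNet-50 and its checkpoint file, used only so the example runs on its own.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the fine-tuned network."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 3)

    def forward(self, x):
        return self.fc(x)


# Simulate a training script that stores the weights inside a checkpoint dict.
trained = TinyNet()
buffer = io.BytesIO()  # an in-memory file instead of a path on disk
torch.save({"epoch": 5, "state_dict": trained.state_dict()}, buffer)
buffer.seek(0)

# torch.load hands back the dict that was saved, not a module.
checkpoint = torch.load(buffer)
print(type(checkpoint), callable(checkpoint))

# Build the architecture first, then copy the saved weights into it.
model = TinyNet()
model.load_state_dict(checkpoint["state_dict"])
model.eval()

with torch.no_grad():
    output = model(torch.randn(1, 8))
print(output.shape)
```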

@colesbury OK, thank you so much.

This covers all the pth files I’ve encountered or created…

        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
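
Only one line of that snippet survives here, so the following is a guess at what such a helper might look like, not the poster's actual code. The name extract_state_dict and the three layouts it handles are assumptions.

```python
import torch
import torch.nn as nn


def extract_state_dict(checkpoint):
    """Return a state_dict from several common checkpoint layouts."""
    if isinstance(checkpoint, nn.Module):
        # The whole module was pickled: ask it for its weights.
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        # Training checkpoint that wraps the weights with other metadata.
        return checkpoint['state_dict']
    # Otherwise assume the object already is a state_dict.
    return checkpoint


source = nn.Linear(4, 2)
layouts = [
    source,                                           # full module
    {'epoch': 3, 'state_dict': source.state_dict()},  # wrapped weights
    source.state_dict(),                              # bare state_dict
]

matches = []
for layout in layouts:
    target = nn.Linear(4, 2)
    target.load_state_dict(extract_state_dict(layout))
    matches.append(torch.equal(target.weight, source.weight))
print(matches)
```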

Hi, I'm trying to load a model using model = torch.load(''), but I get the error below:

Traceback (most recent call last):
  File "", line 11, in <module>
    model = torch.load('')
  File "C:\Users\kanch\Anaconda3\lib\site-packages\torch\", line 231, in load
    return _load(f, map_location, pickle_module)
  File "C:\Users\kanch\Anaconda3\lib\site-packages\torch\", line 379, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'CNN4MNIST' on <module '__main__' from ''>

What should I do? I don't really know what I did wrong.

What is the difference between a .pkl file and a .pt file?

The PyTorch serialization format is built off of pickle (.pkl) but overrides some functionality to handle tensors that share the same backing storage.
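
A small sketch of the shared-storage point: two tensors that view the same memory are saved together, and the relationship is still there after loading. The tensor names and the in-memory buffer are illustrative choices.

```python
import io

import torch

base = torch.arange(6, dtype=torch.float32)
view = base[:3]  # a view: shares its backing storage with base

buffer = io.BytesIO()  # an in-memory file instead of a path on disk
torch.save({"base": base, "view": view}, buffer)
buffer.seek(0)
loaded = torch.load(buffer)

# Writing through the loaded view also changes the loaded base tensor,
# which shows that both were restored on top of one shared storage.
loaded["view"][0] = 100.0
still_shared = loaded["base"][0].item() == 100.0
print(still_shared)
```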

I thought the format was .pth, not .pt.

It's arbitrary, actually; you can use any extension.

Has the problem been solved? I ran into the same problem. Thank you.

Load the state_dict via:

model = MyModel()
state_dict = torch.load("last_brain1.pth")['state_dict']
model.load_state_dict(state_dict)

I am stuck here, at line 154.

This solution does not work.

model is still a dict instead of an nn.Module instance.
Could you print(model) after running my lines of code?


model should be initialized as your model class, e.g.:

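class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # your layer definitions here

    def forward(self, x):
        # your forward pass here
        return x  # placeholder so the template is valid Python

model = MyModel()
state_dict = torch.load("last_brain1.pth")['state_dict']
model.load_state_dict(state_dict)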

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64 * 14 * 14, 30)
        self.fc2 = nn.Linear(30, 3)
        # self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Output sizes below are for the 64x64 input used further down.
        x = self.pool(F.relu(self.conv1(x)))  # output size: [batch_size, 32, 31, 31]
        x = self.pool(F.relu(self.conv2(x)))  # output size: [batch_size, 64, 14, 14]

        x = x.view(-1, 64 * 14 * 14)  # output size: [batch_size, 64 * 14 * 14]
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # x = self.softmax(x)
        return x

x = torch.randn(1, 3, 64, 64)  # (batch size or number of images, RGB channels, height, width)
model = Net()
output = model(x)
model = Net()
state_dict = torch.load("last_brain1.pth")['state_dict']

Here is my code, but it is still not working.

What kind of error do you get, and which line of code raises it?
Could you also show how you created last_brain1.pth?
Maybe 'state_dict' refers to something else than the model.state_dict()?
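
One way to check that last point is to compare the keys stored under 'state_dict' with the keys the model expects. The sketch below uses a built-in nn.Linear layer in place of Net, and an optimizer state as an example of "something else"; compare_keys is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def compare_keys(model, state_dict):
    """List names the model expects but the dict lacks, and the reverse."""
    expected = set(model.state_dict().keys())
    found = set(state_dict.keys())
    return sorted(expected - found), sorted(found - expected)


model = nn.Linear(4, 2)  # stand-in for Net
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Case 1: 'state_dict' holds the model weights, so the keys line up.
good = {'state_dict': model.state_dict()}
missing, unexpected = compare_keys(model, good['state_dict'])
print(missing, unexpected)

# Case 2: 'state_dict' holds something else, here the optimizer state.
bad = {'state_dict': optimizer.state_dict()}
missing_bad, unexpected_bad = compare_keys(model, bad['state_dict'])
print(missing_bad, unexpected_bad)
```

If both lists are empty, model.load_state_dict should accept the dict; anything listed points at a mismatch between how the file was written and the model being constructed.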