This covers all the .pth files I’ve encountered or created:
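# checkpoint is the object returned by torch.load(...)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    # wrapped checkpoint: the weights live under the 'state_dict' key
    model.load_state_dict(checkpoint['state_dict'])
else:
    # plain state dict, e.g. saved with torch.save(model.state_dict(), ...)
    model.load_state_dict(checkpoint)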
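The same check can be pulled into a small helper so it is easy to reuse. This is a sketch that assumes a checkpoint is either a bare state dict or a dict holding the weights under 'state_dict'; the helper name extract_state_dict and the toy nn.Linear model are illustrative, not part of PyTorch.

```python
import torch
import torch.nn as nn


def extract_state_dict(checkpoint):
    """Return the weights from a loaded checkpoint (hypothetical helper).

    Handles the two layouts above: a bare state dict, or a dict
    that stores the weights under the 'state_dict' key.
    """
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


# Usage sketch: nn.Linear stands in for your own model class
model = nn.Linear(4, 2)
wrapped = {"state_dict": model.state_dict(), "epoch": 3}  # hypothetical extra key
bare = model.state_dict()

fresh = nn.Linear(4, 2)
fresh.load_state_dict(extract_state_dict(wrapped))  # wrapped layout
fresh.load_state_dict(extract_state_dict(bare))     # bare layout
same = torch.equal(fresh.weight, model.weight)
```

A bare state dict has keys such as 'weight' and 'bias', never a top-level 'state_dict' key, so the two layouts can be told apart this way.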
Hi, I’m trying to load a model using model = torch.load('cnn4mnist.pt'), but I get the error below:
Traceback (most recent call last):
  File "predict.py", line 11, in <module>
    model = torch.load('cnn4mnist.pt')
  File "C:\Users\kanch\Anaconda3\lib\site-packages\torch\serialization.py", line 231, in load
    return _load(f, map_location, pickle_module)
  File "C:\Users\kanch\Anaconda3\lib\site-packages\torch\serialization.py", line 379, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'CNN4MNIST' on <module '__main__' from 'predict.py'>
What should I do? I don’t understand where I went wrong.
model should be initialized as an instance of your model class before loading the weights, e.g.:
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # your layer definitions here
    def forward(self, x):
        # your forward pass here
        return x

model = MyModel()
state_dict = torch.load("last_brain1.pth")['state_dict']
model.load_state_dict(state_dict)
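The AttributeError in the traceback comes from how torch.save(model) works: it pickles the whole module object, and pickle stores a reference to the class by module and name (here __main__.CNN4MNIST, because the training script was run directly) rather than the class code itself. predict.py has no CNN4MNIST defined, so the object cannot be rebuilt. One fix is to define or import the class in predict.py before calling torch.load; a more robust pattern is to save only the state dict and rebuild the model from its class before loading, as sketched here. The CNN4MNIST layers below are hypothetical, since the original architecture isn't shown, and an in-memory io.BytesIO buffer stands in for the .pt file.

```python
import io

import torch
import torch.nn as nn


class CNN4MNIST(nn.Module):
    # Hypothetical architecture: the real layers from the post are unknown.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)  # keeps 28x28
        self.fc = nn.Linear(8 * 28 * 28, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.fc(x.flatten(1))


# Training script: save only the parameters, not the pickled module object
trained = CNN4MNIST()
buffer = io.BytesIO()  # stands in for a file such as 'cnn4mnist.pt'
torch.save(trained.state_dict(), buffer)

# predict.py: rebuild the model from its class, then load the weights
buffer.seek(0)
model = CNN4MNIST()
model.load_state_dict(torch.load(buffer))
model.eval()  # switch to inference mode before predicting

with torch.no_grad():
    logits = model(torch.randn(2, 1, 28, 28))  # batch of 2 MNIST-sized images
```

Because the state dict contains only tensors keyed by parameter name, loading it does not depend on the class being importable from __main__; the class only has to be available where model = CNN4MNIST() is called.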
What kind of error do you get, and which line of code raises it?
Could you also show how you’ve created last_brain1.pth?
Maybe 'state_dict' refers to something other than model.state_dict()?
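For context on how a file like last_brain1.pth might have been created: a checkpoint that is read back with torch.load(...)['state_dict'] is commonly written as a dict that bundles the model weights with other training state. This sketch assumes that layout; the 'optimizer' and 'epoch' keys, the nn.Linear stand-in, and the BytesIO buffer are hypothetical, since the actual contents of last_brain1.pth aren't known.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the model behind last_brain1.pth
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Saving: bundle the weights with other training state under named keys
checkpoint = {
    "state_dict": model.state_dict(),     # the model weights
    "optimizer": optimizer.state_dict(),  # hypothetical extra entry
    "epoch": 5,                           # hypothetical extra entry
}
buffer = io.BytesIO()  # stands in for a file such as "last_brain1.pth"
torch.save(checkpoint, buffer)

# Loading: only the 'state_dict' entry belongs in model.load_state_dict()
buffer.seek(0)
loaded = torch.load(buffer)
print(sorted(loaded.keys()))  # ['epoch', 'optimizer', 'state_dict']

restored = nn.Linear(4, 2)
restored.load_state_dict(loaded["state_dict"])
restored_optimizer = torch.optim.SGD(restored.parameters(), lr=0.1)
restored_optimizer.load_state_dict(loaded["optimizer"])
```

Printing the keys of the loaded object, as above, is a quick way to check what 'state_dict' actually refers to in a given checkpoint before passing it to load_state_dict().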