Loading trained model using load_state_dict


num_classes = dset.num_classes
resnet_model18 = torchvision.models.resnet18(pretrained=True)
num_ftrs = resnet_model18.fc.in_features
resnet_model18.fc = nn.Linear(num_ftrs, num_classes)  # replace the final layer to match our number of classes
resnet_model18 = resnet_model18.to(device)
...
resnet_model18_ft = train_model(resnet_model18, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=2)

# save only the weights (state_dict), with a timestamped filename
torch.save(resnet_model18.state_dict(), 'models/checkpoint_{}'.format(time.strftime("%Y%m%d-%H%M%S")))

After several hours …

import torch
import torchvision
import torch.nn as nn

PATH = 'models/checkpoint_20190501-005720'

model = torchvision.models.resnet18(pretrained=False)
num_classes = 203094 # num_classes from train
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

model = model.load_state_dict(torch.load(PATH))

But this returns None. Am I missing something here?

You shouldn’t expect any output from that function. It overwrites the model in place; it doesn’t create a new one.
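
For example, a minimal sketch just to illustrate the in-place behaviour (the resnet18 and the round trip through its own state_dict are only stand-ins for your checkpoint):

import torch
import torchvision
import torch.nn as nn

model = torchvision.models.resnet18(pretrained=False)
state = model.state_dict()          # stand-in for torch.load(PATH)
ret = model.load_state_dict(state)  # copies the weights into `model` itself
print(ret)                          # None here (newer PyTorch versions may print a report of missing/unexpected keys)
# `model` is still the usable nn.Module; `ret` is not a model.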

Can you explain that a little more? I am trying to understand TheModelClass from the docs: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference

load_state_dict is a bound method of the model class. It has no return value, which is why you get None.
Once you have defined the model, just do
model.load_state_dict(torch.load(PATH)), but don’t do
model = model.load_state_dict(torch.load(PATH))
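
For reference, a corrected version of the loading snippet from the first post might look like this (PATH and num_classes taken from there):

import torch
import torchvision
import torch.nn as nn

PATH = 'models/checkpoint_20190501-005720'
num_classes = 203094  # must match the value used during training

model = torchvision.models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)  # rebuild the same architecture as in training

model.load_state_dict(torch.load(PATH))  # mutates `model` in place; don't reassign the return value
model.eval()  # put the model in evaluation mode before inference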