Loading a pre-trained model

I have a pre-trained ResNet-50 network saved in model_file, and I would like to load it.

The following is my code:

import torch

def load_model(model_file, cuda):
    model_ft = resnet(False, 50)    # ResNet-50
    # replace the classifier head with a 4-class output layer
    model_ft.fc = torch.nn.Linear(model_ft.fc.in_features, 4)
    checkpoint = torch.load(model_file)
    model_ft.load_state_dict(checkpoint['model'].state_dict())
    return model_ft

However, I get the following error:
Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.cur_mean", "bn1.cur_var", "bn1.running_m", "bn1.running_m2", "bn1.running_logvar", "bn1.running_logvar2", .... Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", .....

This seems to happen because the layer names in my pre-trained checkpoint start with module., while the layer names of the ResNet (model_ft) do not.
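The mismatch can be reproduced with a small stand-in network. This is a minimal sketch assuming a hypothetical two-layer model whose layers are named like the first ResNet layers (conv1, bn1), not the real ResNet-50; it wraps the model in nn.DataParallel, compares the state_dict keys, and shows that loading the prefixed keys into an unwrapped copy fails:

```python
from collections import OrderedDict

import torch.nn as nn


def make_net():
    # Hypothetical stand-in for the first layers of a ResNet-50.
    return nn.Sequential(OrderedDict([
        ("conv1", nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)),
        ("bn1", nn.BatchNorm2d(64)),
    ]))


net = make_net()
# nn.DataParallel keeps the wrapped network in its `module` attribute,
# so every state_dict key gains a "module." prefix.
wrapped = nn.DataParallel(net)

plain_keys = list(net.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())
print(plain_keys[:2])    # ['conv1.weight', 'bn1.weight']
print(wrapped_keys[:2])  # ['module.conv1.weight', 'module.bn1.weight']

# Loading the prefixed keys into an unwrapped model fails with strict=True (the default).
try:
    make_net().load_state_dict(wrapped.state_dict())
    mismatch_detected = False
except RuntimeError as err:
    mismatch_detected = "Missing key(s) in state_dict" in str(err)
print(mismatch_detected)  # True
```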

I tried to work around the problem as follows:

    new = list(checkpoint.items())
    my_model_kvpair = model_ft.state_dict()
    count = 0
    # pair each model key with the checkpoint entry at the same position
    for key, value in model_ft.state_dict().items():
        layer_name, weights = new[count]
        my_model_kvpair[key] = weights
        count += 1

However, this fails at layer_name, weights = new[count] with IndexError: list index out of range.

For reference, this is what new contains:

new[0] = (module): ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(…
new[1] = out of range

type(checkpoint) = <class 'dict'>
type(model_ft) = <class 'networks.resnet.ResNet'>
type(model_ft.state_dict) = <class 'dict'>
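The printout of new[0] and the types above are consistent with checkpoint being a plain dict holding a single entry whose value is the whole nn.DataParallel-wrapped model, which would explain the IndexError. A minimal sketch of that structure, assuming a hypothetical one-layer stand-in for the ResNet-50 (the key name 'model' is taken from the code in the question):

```python
import torch.nn as nn

# Hypothetical stand-in for the saved ResNet-50.
net = nn.Sequential(nn.Linear(4, 2))

# Assumed checkpoint layout: one entry holding the whole wrapped model.
checkpoint = {"model": nn.DataParallel(net)}

new = list(checkpoint.items())
print(len(new))  # 1, so new[1] raises IndexError

# new[0] is a ("model", DataParallel(...)) pair, not a (layer_name, tensor) pair.
name, obj = new[0]
print(name, type(obj).__name__)  # model DataParallel

# The per-layer tensors live in the wrapped model's state_dict instead.
weights = checkpoint["model"].state_dict()
print(list(weights.keys()))  # ['module.0.weight', 'module.0.bias']
```

So the positional loop compares model keys against the entries of the outer dict rather than against layer tensors, and it runs out of entries after the first iteration.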

How can I load and use my pre-trained model?

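It looks like you’ve saved the state_dict from an nn.DataParallel instance and are trying to load it into a vanilla module. nn.DataParallel stores the original model in its module attribute, so every key in its state_dict is prefixed with module.
To deal with this issue, you could either remove the module. prefix from the keys of the state_dict, or save the state_dict of the wrapped model in the first place via torch.save(model.module.state_dict(), PATH).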
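Both options can be sketched as follows. This is a minimal example, not a drop-in fix: make_net is a hypothetical stand-in for resnet(False, 50) with its 4-class head, strip_module_prefix is a hypothetical helper, and an in-memory io.BytesIO buffer stands in for PATH:

```python
import io
from collections import OrderedDict

import torch
import torch.nn as nn


def make_net():
    # Hypothetical stand-in for resnet(False, 50) with a 4-class fc layer.
    return nn.Sequential(OrderedDict([
        ("conv1", nn.Conv2d(3, 8, kernel_size=3, bias=False)),
        ("bn1", nn.BatchNorm2d(8)),
    ]))


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


wrapped = nn.DataParallel(make_net())

# Option 1: strip the "module." prefix before loading into a plain model.
model_a = make_net()
model_a.load_state_dict(strip_module_prefix(wrapped.state_dict()))

# Option 2: save the state_dict of the wrapped module, whose keys have no prefix.
buffer = io.BytesIO()  # stands in for PATH
torch.save(wrapped.module.state_dict(), buffer)
buffer.seek(0)
model_b = make_net()
model_b.load_state_dict(torch.load(buffer))
```

In the question's setup, if checkpoint['model'] is indeed the nn.DataParallel object (as the (module): ResNet( printout suggests), checkpoint['model'].module.state_dict() should likewise return prefix-free keys for model_ft.load_state_dict. Note that since PyTorch 2.6, torch.load defaults to weights_only=True, so a checkpoint that pickles a whole module may need torch.load(model_file, weights_only=False), which should only be used for files you trust.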

Thank you for your answer. It helped!