PyTorch Pretrained VGG19 KeyError

I’m working on fine-tuning the first 10 layers of the VGG19 net to extract features from images, but I’m getting the error below and couldn’t find a workaround:

Traceback (most recent call last):
  File "TODO-train_from_scratch.py", line 390, in <module>
    main()
  File "TODO-train_from_scratch.py", line 199, in main
    model.load_state_dict(weights_load)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 339, in load_state_dict
    raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 

Corresponding snippet from training code is:

# create model
vgg19 = models.vgg19(pretrained = True)

vgg19_state_dict = vgg19.state_dict()

vgg19_keys = vgg19_state_dict.keys()    

model = get_model()

weights_load = {}   

for i in range(20):
    weights_load[model.state_dict().keys()[i]] = vgg19_state_dict[vgg19_keys[i]]

model.load_state_dict(weights_load)
model = torch.nn.DataParallel(model).cuda()

I think it’s because the order of your weights_load is not kept. collections.OrderedDict may help you.

Could you clarify? I couldn’t understand exactly what I should change.

Adding to moskomule’s suggestion:

You save the weights in the dictionary weights_load in order. However, in Python 2.7 (which your traceback shows you are using), and in Python versions before 3.7 generally, a regular dictionary does not guarantee insertion order. When you print the weights_load dict, you will likely see a different order than in the VGG weight dict. The error occurs when you try to load an unordered set of weights into the model.

To make sure a dictionary tracks insertion order, use collections.OrderedDict. Simply replace

weights_load = {}

with

import collections
weights_load = collections.OrderedDict()

That should maintain the order of the weights when loading them into the model.
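As a minimal sketch of the ordering behaviour described above, the snippet below fills an OrderedDict and reads the keys back. The parameter names are hypothetical and None stands in for the actual tensors; it only illustrates ordering, not loading into a model.

```python
from collections import OrderedDict

# Hypothetical parameter names, used only to illustrate ordering.
names = ["features.0.weight", "features.0.bias",
         "features.2.weight", "features.2.bias"]

ordered = OrderedDict()
for name in names:
    ordered[name] = None  # placeholder standing in for a tensor

# The keys come back in exactly the order they were inserted.
keys_in_order = list(ordered.keys())
```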

Hope it works.

Daniel

Thanks for the detailed explanation. However, I got the same error despite the change.

After loading the first 10 layers of VGG19 into weights_load, you can fill the missing keys with the model's own values. Simply add this loop after your for loop:

model_state = model.state_dict()
# Keep the model's own values for every entry after the first 20.
for key in list(model_state.keys())[20:]:
    weights_load[key] = model_state[key]

Fazil
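The idea in the last answer (give load_state_dict every key the model expects, and override only the first entries) could be sketched as a small helper. This is an illustration, not code from the thread: merge_first_entries and all key names are hypothetical, and plain strings stand in for tensors so the snippet runs without PyTorch.

```python
from collections import OrderedDict

def merge_first_entries(source_state, target_state, n):
    """Return a full copy of target_state whose first n entries take
    their values from the first n entries of source_state (by position)."""
    merged = OrderedDict(target_state)  # every target key stays present
    target_keys = list(target_state.keys())[:n]
    source_values = list(source_state.values())[:n]
    for key, value in zip(target_keys, source_values):
        merged[key] = value
    return merged

# Stand-in state dicts with hypothetical key names; strings replace tensors.
pretrained = OrderedDict([
    ("features.0.weight", "vgg_w0"),
    ("features.0.bias", "vgg_b0"),
    ("features.2.weight", "vgg_w2"),
])
custom = OrderedDict([
    ("block.0.weight", "init_w0"),
    ("block.0.bias", "init_b0"),
    ("head.weight", "init_head"),
])

merged = merge_first_entries(pretrained, custom, 2)
```

With the real models from the question, the call would presumably look like model.load_state_dict(merge_first_entries(vgg19.state_dict(), model.state_dict(), 20)), assuming the first 20 entries of both state dicts line up by position and have matching shapes.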