Missing keys & unexpected keys in state_dict when loading self-trained model

Hi there,

I’ve run into this problem and have been stuck on it for a while. I have a fairly large labeled image dataset, and I chose to train a VGG16 on it, starting from PyTorch’s ImageNet example.

I first organized the data into three splits, train, val, and test; each split contains one subdirectory per class label.


Then I trained with this command:

export CUDA_VISIBLE_DEVICES=device_id
python3 main.py /path/to/my/dataset -a vgg16 -b 32 --lr 0.001 

Training seemed to go fine, reaching nearly 90% top-5 accuracy. The best checkpoint was saved as model_best.pth.tar.

However, when I try to run inference on some images with this model, loading fails with the following error:

RuntimeError: Error(s) in loading state_dict for VGG:
        Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias".
        Unexpected key(s) in state_dict: "features.module.0.weight", "features.module.0.bias", "features.module.2.weight", "features.module.2.bias", "features.module.5.weight", "features.module.5.bias", "features.module.7.weight", "features.module.7.bias", "features.module.10.weight", "features.module.10.bias", "features.module.12.weight", "features.module.12.bias", "features.module.14.weight", "features.module.14.bias", "features.module.17.weight", "features.module.17.bias", "features.module.19.weight", "features.module.19.bias", "features.module.21.weight", "features.module.21.bias", "features.module.24.weight", "features.module.24.bias", "features.module.26.weight", "features.module.26.bias", "features.module.28.weight", "features.module.28.bias".

Could anyone give me some advice?

Thanks in advance.

EDIT: the loading snippet:

import torch
from torchvision import models
model = models.__dict__[args.arch]() # arch is fed as 'vgg16'
checkpoint = torch.load(model_file_name) # i.e., model_best.pth.tar

The Traceback:

Traceback (most recent call last):
  File "classifier.py", line 98, in <module>
  File "classifier.py", line 78, in main
    _im, im = Classifier(args.model).infer(args.image)
  File "classifier.py", line 48, in __init__
  File "/home/xxx/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 721, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))

It seems you used nn.DataParallel when saving the model.
You could either wrap your current model in nn.DataParallel again, or just remove the "module." part from the keys.
Here is a similar thread with some suggestions.
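A minimal sketch of both suggestions, assuming the checkpoint was produced the way PyTorch's ImageNet example does it: as far as I can tell, that script wraps only `model.features` in `nn.DataParallel` for AlexNet/VGG (hence `features.module.0.weight` rather than `module.features.0.weight`) and stores the weights under the `'state_dict'` key of the checkpoint dict. `TinyNet` and `strip_module_segment` are hypothetical names; the small model stands in for VGG16 so the example runs quickly on a CPU.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for VGG: a `features` block plus a classifier."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU(), nn.Conv2d(4, 4, 3))
        self.classifier = nn.Linear(4, 2)


def strip_module_segment(state_dict):
    """Drop every 'module' path segment from the keys.

    Handles both 'module.features.0.weight' (whole model wrapped) and
    'features.module.0.weight' (only model.features wrapped), without
    touching names that merely contain 'module', such as 'submodule.weight'.
    """
    return {
        ".".join(part for part in key.split(".") if part != "module"): value
        for key, value in state_dict.items()
    }


# Simulate the training side: only the feature extractor is wrapped.
trained = TinyNet()
trained.features = nn.DataParallel(trained.features)
checkpoint = {"arch": "tinynet", "state_dict": trained.state_dict()}

# Option 1: strip the segment and load into a plain (unwrapped) model.
plain = TinyNet()
plain.load_state_dict(strip_module_segment(checkpoint["state_dict"]))

# Option 2: wrap the new model exactly like the training script did.
wrapped = TinyNet()
wrapped.features = nn.DataParallel(wrapped.features)
wrapped.load_state_dict(checkpoint["state_dict"])
```

Option 1 is usually the more convenient one for inference, since the resulting model no longer depends on DataParallel.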


Thanks for your suggestions! @ptrblck

Can we load the state_dict of a previously trained network into a new network whose structure has changed? For example, I only need part of the previously trained network’s state_dict, which I managed to load. Now I’ve added one more layer to the new network, and loading fails because the saved state_dict does not contain this output layer. How should I proceed?


Could you try to load the state_dict using model.load_state_dict(torch.load(PATH), strict=False)?
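A small sketch of what strict=False does, using two hypothetical nn.Sequential models where the new one has an extra output layer. load_state_dict returns the missing and unexpected keys, so it is worth inspecting them: strict=False also silently skips keys that only fail to match because of a prefix.

```python
import torch.nn as nn

old = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
# Hypothetical new model: same first layers plus one extra output layer.
new = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4), nn.ReLU(), nn.Linear(4, 2))

result = new.load_state_dict(old.state_dict(), strict=False)
print(result.missing_keys)     # ['4.weight', '4.bias'] (keep their initial values)
print(result.unexpected_keys)  # []

# Guard against silently loading nothing, e.g. when every key has a 'module.' prefix.
assert not result.unexpected_keys, result.unexpected_keys
```

Note that strict=False does not relax shape checks: a layer whose size changed still raises a size-mismatch error.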


Thank you so much @ptrblck, it worked perfectly!

In my case it no longer threw any errors, but it didn’t load the state_dict correctly either.
In the end I just had to remove the “module.” part from the keys.

Could you explain more clearly how to wrap the current model in nn.DataParallel, maybe with an example?
I get the “Missing key(s) in state_dict” error, and when I save the model I just use torch.save().
I am new to PyTorch, thanks so much!

You could just wrap the model in nn.DataParallel and push it to the device:

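import torch.nn as nn
model = Model(input_size, output_size)  # your own model class and sizes
model = nn.DataParallel(model)
model = model.to('cuda')  # push the wrapped model to the device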

I would not recommend saving the model directly; save its state_dict instead, as explained here.
Also, after you’ve wrapped the model in nn.DataParallel, the original model will be accessible via model.module, so you might want to store the state_dict via torch.save(model.module.state_dict(), 'file_name.pt').
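A minimal sketch of that save/load round trip, assuming a small hypothetical nn.Sequential network and a temporary directory in place of your own model and path. Saving model.module.state_dict() produces keys without the "module." prefix, so the file loads into an unwrapped model.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_net():
    # Hypothetical architecture; use your own model class here.
    return nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))


parallel_net = nn.DataParallel(make_net())
path = os.path.join(tempfile.mkdtemp(), "file_name.pt")

# Save the inner module's state_dict, not the DataParallel wrapper's.
torch.save(parallel_net.module.state_dict(), path)

# Later (e.g. in an inference script): no DataParallel needed.
fresh = make_net()
fresh.load_state_dict(torch.load(path))
print(sorted(fresh.state_dict()))  # ['0.bias', '0.weight', '2.bias', '2.weight']
```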


I use torch.save(model.state_dict(), 'file_name.pt') to save the model. If I don’t use nn.DataParallel when saving the model, I also don’t need it when loading, right?
If so: I call torch.save(model.state_dict(), 'mymodel.pt') in one .py file during training, and when I try to load it in a new file with model.load_state_dict(torch.load('mymodel.pt')), I get the error
RuntimeError: Error(s) in loading state_dict for model: Missing key(s) in state_dict: "l1.W", "l1.b", "l2.W", "l2.b".
My model is a self-defined model with two self-constructed layers l1 and l2, each has parameter W and b.
Do you have any idea what causes this? Thanks!
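One way to narrow this down is to compare the key sets of the saved file and the freshly built model before calling load_state_dict. The sketch below assumes hypothetical classes (MyLinear, MyModel) that mirror the description: W and b are registered with nn.Parameter, which is what makes them appear as l1.W, l1.b, and so on in the state_dict. Non-empty "missing" or "unexpected" lists would point at the mismatch, for example a differently structured model in the training script.

```python
import os
import tempfile

import torch
import torch.nn as nn


class MyLinear(nn.Module):
    """Hypothetical self-constructed layer with parameters W and b."""

    def __init__(self, n_in, n_out):
        super().__init__()
        # nn.Parameter registers the tensors, so they appear in state_dict()
        # as '<layer name>.W' and '<layer name>.b'.
        self.W = nn.Parameter(torch.randn(n_in, n_out))
        self.b = nn.Parameter(torch.zeros(n_out))

    def forward(self, x):
        return x @ self.W + self.b


class MyModel(nn.Module):
    """Hypothetical two-layer model mirroring the description (l1, l2)."""

    def __init__(self):
        super().__init__()
        self.l1 = MyLinear(4, 8)
        self.l2 = MyLinear(8, 2)

    def forward(self, x):
        return self.l2(torch.relu(self.l1(x)))


# Training script: save only the state_dict.
path = os.path.join(tempfile.mkdtemp(), "mymodel.pt")
torch.save(MyModel().state_dict(), path)

# Inference script: rebuild the same class, then compare key sets before loading.
model = MyModel()
saved = torch.load(path, map_location="cpu")
print("missing:", sorted(set(model.state_dict()) - set(saved)))     # missing: []
print("unexpected:", sorted(set(saved) - set(model.state_dict())))  # unexpected: []
model.load_state_dict(saved)
```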

Could you post the model definition, please?

Yes, that is correct.

I am also facing the same issue, and I don’t use nn.DataParallel.
I am using a slightly modified version of this repo (https://github.com/aitorzip/PyTorch-CycleGAN) in a Kaggle notebook.
Here’s how I save:

torch.save({
    'epoch': epoch,
    'model_state_dict': netG_A2B.state_dict(),
    'optimizer_state_dict': optimizer_G.state_dict(),
    'loss_histories': loss_histories,
}, f'netG_A2B_{epoch:03d}.pth')

Here’s how I load:

checkpoint = torch.load(weights_path)['model_state_dict']

And here’s a representative snippet of the (longer) error message:

Missing key(s) in state_dict: "1.weight", "1.bias", "4.weight", "4.bias", "7.weight", "7.bias".
Unexpected key(s) in state_dict: "model.1.weight", "model.1.bias", "model.4.weight", "model.4.bias", "model.7.weight", "model.7.bias". 

I opted for this method to fix it:

checkpoint = torch.load(weights_path, map_location=self.device)['model_state_dict']
for key in list(checkpoint.keys()):
    if 'model.' in key:
        checkpoint[key.replace('model.', '')] = checkpoint[key]
        del checkpoint[key]
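Recent PyTorch releases also ship a helper for this, torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which strips a prefix in place from every key that starts with it. A sketch, assuming a hypothetical Generator wrapper that keeps its layers in self.model (producing model.N.weight-style keys) and loading into a bare nn.Sequential. Unlike the loop above, it only strips a leading prefix, so it would not fix keys like features.module.0.weight.

```python
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present


class Generator(nn.Module):
    """Hypothetical wrapper that keeps its layers in `self.model`."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))


state_dict = Generator().state_dict()  # keys: 'model.0.weight', 'model.0.bias', ...

# Strip 'model.' in place from every key that starts with it.
consume_prefix_in_state_dict_if_present(state_dict, "model.")

bare = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
bare.load_state_dict(state_dict)
```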

I was facing the same issue when using PyTorch Lightning.

@williamFalcon, please take note: this is how I solved the problem.

You can strip the “module.” prefix from the keys in the state_dict as follows:

pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}

Ideally, if you use DataParallel, save the checkpoint for inference as follows:
torch.save(model.module.state_dict(), 'model_ckpt.pt')

This might also be useful if you later want to run inference on the CPU.
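For CPU-only inference, the other piece is the map_location argument of torch.load, which remaps tensors that were saved on a GPU. A minimal sketch with a hypothetical nn.Linear model and an in-memory buffer standing in for model_ckpt.pt:

```python
import io

import torch
import torch.nn as nn

# Hypothetical model; on a GPU machine its weights are saved as CUDA tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = nn.Linear(3, 1).to(device)

buffer = io.BytesIO()  # stands in for 'model_ckpt.pt'
torch.save(net.state_dict(), buffer)
buffer.seek(0)

# map_location="cpu" remaps every stored tensor to the CPU, so loading
# also works on a machine without CUDA.
state = torch.load(buffer, map_location="cpu")
cpu_model = nn.Linear(3, 1)
cpu_model.load_state_dict(state)
cpu_model.eval()

with torch.no_grad():
    print(cpu_model(torch.ones(1, 3)))
```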

Ref: https://stackoverflow.com/a/61854807/3177661


Thanks, you saved my day!

It worked for me, no more error.

works for me, thank you

This worked. Thanks a lot!!