Hi there,
I’ve run into this problem and have been stuck on it for a while. I have a fairly large labeled image dataset, and I chose to train a VGG16 on it, starting from PyTorch’s ImageNet example.
I first organized the data into three splits, namely train, val, and test. Under train and val there is one subdirectory per class label, while test holds the image files directly, like:
train
    label0
    label1
    ...
val
    label0
    label1
    ...
test
    file0
    file1
    ...
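For reference, the ImageNet example reads train and val with torchvision's ImageFolder, which (as I understand it) treats each immediate subdirectory as one class and assigns indices in sorted name order. The plain-Python sketch below mimics only that label-discovery step; `list_class_folders` is a hypothetical helper for illustration, not the torchvision implementation. It also shows why the flat test folder cannot be read the same way: it has no class subdirectories.

```python
import os
import tempfile
from typing import Dict, List, Tuple


def list_class_folders(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Hypothetical helper mimicking ImageFolder-style label discovery.

    Each immediate subdirectory is one class; names are sorted so the
    class-to-index mapping is deterministic. Plain files are ignored.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        # A flat folder such as test/ (only image files) has no labels to read.
        raise FileNotFoundError(f"no class subdirectories in {directory!r}")
    return classes, {name: idx for idx, name in enumerate(classes)}


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        train_dir = os.path.join(root, "train")
        for name in ("label1", "label0"):
            os.makedirs(os.path.join(train_dir, name))
        # A stray file next to the class folders is skipped (not a directory).
        open(os.path.join(train_dir, "notes.txt"), "w").close()
        print(list_class_folders(train_dir))
        # (['label0', 'label1'], {'label0': 0, 'label1': 1})
```

Sorting the names is what keeps label indices stable between the training run and any later inference run on the same folder layout.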
Then I trained with this command:
export CUDA_VISIBLE_DEVICES=device_id
python3 main.py /path/to/my/dataset -a vgg16 -b 32 --lr 0.001
Training seems to go fine, reaching nearly 90% top-5 accuracy. The saved model file is model_best.pth.tar.
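Top-5 accuracy counts a prediction as correct when the true label is among the five highest-scoring classes. The ImageNet example computes this on tensors; the sketch below is only a plain-Python restatement of the metric for illustration, and `topk_accuracy` is a hypothetical name, not the example's function.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int], k: int = 5) -> float:
    """Percentage of samples whose true label is among their k highest-scoring classes."""
    if len(scores) != len(targets) or not targets:
        raise ValueError("scores and targets must be non-empty and the same length")
    hits = 0
    for row, target in zip(scores, targets):
        # Class indices ordered by descending score; keep the first k.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += target in top_k
    return 100.0 * hits / len(targets)


if __name__ == "__main__":
    scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
    targets = [2, 0, 0]
    for k in (1, 2, 3):
        print(k, round(topk_accuracy(scores, targets, k), 2))
```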
After that, I wanted to run inference on some images with the model, but loading the checkpoint fails with the following error:
RuntimeError: Error(s) in loading state_dict for VGG:
Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias".
Unexpected key(s) in state_dict: "features.module.0.weight", "features.module.0.bias", "features.module.2.weight", "features.module.2.bias", "features.module.5.weight", "features.module.5.bias", "features.module.7.weight", "features.module.7.bias", "features.module.10.weight", "features.module.10.bias", "features.module.12.weight", "features.module.12.bias", "features.module.14.weight", "features.module.14.bias", "features.module.17.weight", "features.module.17.bias", "features.module.19.weight", "features.module.19.bias", "features.module.21.weight", "features.module.21.bias", "features.module.24.weight", "features.module.24.bias", "features.module.26.weight", "features.module.26.bias", "features.module.28.weight", "features.module.28.bias".
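Comparing the two lists, every unexpected key looks like a missing key with an extra `module` path component. As far as I know, `torch.nn.DataParallel` stores the wrapped network in an attribute named `module`, so wrapping a submodule inserts `module.` into its parameter names, and if I read the ImageNet example correctly it wraps `model.features` this way for VGG. The sketch below checks that hypothesis on the key names alone; both helpers are hypothetical and only do string manipulation, so they run without torch.

```python
from typing import Iterable


def drop_segment(key: str, segment: str = "module") -> str:
    """Remove every dotted path component equal to `segment` from a parameter name."""
    return ".".join(part for part in key.split(".") if part != segment)


def mismatch_is_wrapper_prefix(missing: Iterable[str], unexpected: Iterable[str],
                               segment: str = "module") -> bool:
    """True if stripping `segment` from the unexpected keys gives exactly the missing keys."""
    missing_set = set(missing)
    renamed = {drop_segment(key, segment) for key in unexpected}
    return bool(missing_set) and renamed == missing_set


if __name__ == "__main__":
    missing = ["features.0.weight", "features.0.bias", "features.2.weight"]
    unexpected = ["features.module.0.weight", "features.module.0.bias",
                  "features.module.2.weight"]
    print(mismatch_is_wrapper_prefix(missing, unexpected))
```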
Could anyone give me some advice?
Thanks in advance.
EDIT: the loading snippet:
import torch
from torchvision import models
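arch = 'vgg16'  # passed in as args.arch in my script
model_file_name = 'model_best.pth.tar'  # checkpoint written during training
model = models.__dict__[arch]()
model.cuda()
checkpoint = torch.load(model_file_name)
model.load_state_dict(checkpoint['state_dict'])  # raises the RuntimeError above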
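One possible workaround is to rename the keys before calling load_state_dict. The sketch below assumes, as the error message suggests, that the only difference is the extra `module` component, and that no real submodule of the network is named `module`. `remove_wrapper_segment` is a hypothetical helper, not part of PyTorch; it works on any mapping of names to values, so it runs without torch, while in practice the values would be the checkpoint's tensors.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")


def remove_wrapper_segment(state_dict: Mapping[str, V],
                           segment: str = "module") -> "OrderedDict[str, V]":
    """Return a copy of `state_dict` with every `segment` path component removed.

    Assumes no genuine submodule is named `segment`; if renaming would make
    two keys collide, a KeyError is raised instead of silently overwriting.
    """
    cleaned: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state_dict.items():
        new_key = ".".join(part for part in key.split(".") if part != segment)
        if new_key in cleaned:
            raise KeyError(f"renaming {key!r} collides with existing key {new_key!r}")
        cleaned[new_key] = value  # original key order is preserved
    return cleaned
```

With the snippet above, the last line would become `model.load_state_dict(remove_wrapper_segment(checkpoint['state_dict']))`. An alternative that avoids renaming is to reproduce the training-time wrapping before loading, i.e. `model.features = torch.nn.DataParallel(model.features)` followed by `model.cuda()`, at the cost of keeping the wrapper around during inference.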