CheXNet code error (RuntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: "module.densenet121.features.conv0.weight")

I tried to run the CheXNet code from GitHub (https://github.com/arnoweng/CheXNet), but it failed with the following error message:
RuntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: "module.densenet121.features.conv0.weight", …
I ran the code on a machine with multiple GPUs. I suspected that this was the cause, so I set CUDA_VISIBLE_DEVICES=0, but the error persisted. How can I solve this problem?

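You have to load the state dict before wrapping the model in DataParallel. DataParallel stores the original model under an attribute called module, so every key in its state dict gains a module. prefix and no longer matches the keys in the checkpoint.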
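To make the prefix visible, here is a minimal sketch that uses a toy nn.Linear layer as a stand-in for the real network (the layer and the variable names are assumptions for illustration, not part of the CheXNet code). It prints the keys before and after wrapping, then loads the weights before the model is wrapped:

```python
import torch.nn as nn

# Toy stand-in for the real network (assumption for illustration only).
plain = nn.Linear(4, 2)
plain_keys = list(plain.state_dict().keys())

# Wrapping stores the model under `.module`, which prefixes every key.
wrapped = nn.DataParallel(plain)
wrapped_keys = list(wrapped.state_dict().keys())

print(plain_keys)    # ['weight', 'bias']
print(wrapped_keys)  # ['module.weight', 'module.bias']

# A checkpoint saved from an unwrapped model loads cleanly into an
# unwrapped model; wrap only after the weights are in place.
checkpoint = plain.state_dict()
model = nn.Linear(4, 2)
model.load_state_dict(checkpoint)
model = nn.DataParallel(model)
```

The order of the last three lines is the point: load_state_dict runs while the key names still match, and the wrapper is applied afterwards.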

@JuanFMontesinos's answer will work, but only if your use case is model (CPU or single GPU) -> model (multiple GPUs). Multi-GPU to multi-GPU should work normally, as should CPU / single GPU to CPU / single GPU.

If the use case is model (multiple GPUs) -> model (CPU or single GPU), you need this hack if the model was already saved:

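from collections import OrderedDict

# `checkpoint` is assumed to be the loaded state dict (parameter name -> tensor).
substring = 'module.'
checkpoint_tmp = OrderedDict()
for k in checkpoint:
    # Drop a leading 'module.'; keys without the prefix are kept unchanged.
    new_k = k[len(substring):] if k.startswith(substring) else k
    checkpoint_tmp[new_k] = checkpoint[k]
checkpoint = checkpoint_tmp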

This removes the module. prefix that DataParallel prepends to the parameter names, so the checkpoint then loads normally.
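The same renaming can be wrapped in a small helper. The sketch below is only an illustration: the function name strip_module_prefix is hypothetical, the checkpoint layout with an 'epoch' entry and a nested 'state_dict' entry is an assumption, and plain numbers stand in for tensors so that it runs without PyTorch. It also shows that the renaming has to be applied to the dictionary that actually holds the parameter names:

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Hypothetical checkpoint layout: weights nested under a 'state_dict' key.
# Plain numbers stand in for tensors so the example runs without PyTorch.
checkpoint = {
    "epoch": 3,
    "state_dict": OrderedDict([
        ("module.features.conv0.weight", 1.0),
        ("module.classifier.bias", 2.0),
    ]),
}

# Rename the parameter names, not the top-level keys of the checkpoint.
weights = strip_module_prefix(checkpoint["state_dict"])
print(list(weights))  # ['features.conv0.weight', 'classifier.bias']
```

If the file was saved as a bare state dict, the helper would be applied to the loaded object directly instead of to checkpoint["state_dict"].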

If the model was not already saved, you can prevent this from happening in the future by saving as follows (taken from this issue):

try:
    model_state_dict = model.module.state_dict()
except AttributeError:
    model_state_dict = model.state_dict()

This takes the state dict from the underlying model rather than from the wrapper, so the saved keys carry no module. prefix.
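A variant of the same idea, sketched with an explicit type check instead of try/except. The helper name unwrapped_state_dict, the toy nn.Linear model, and the in-memory buffer used in place of a checkpoint file are all assumptions for illustration:

```python
import io

import torch
import torch.nn as nn


def unwrapped_state_dict(model):
    """Return the state dict of the underlying model, with or without a wrapper."""
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    if isinstance(model, wrappers):
        model = model.module  # the wrapped model lives under `.module`
    return model.state_dict()


model = nn.DataParallel(nn.Linear(4, 2))  # toy stand-in for the real network

buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save(unwrapped_state_dict(model), buffer)
buffer.seek(0)

saved_keys = list(torch.load(buffer, map_location="cpu").keys())
print(saved_keys)  # ['weight', 'bias']
```

A checkpoint saved this way can be loaded into an unwrapped model on any device, and the model can be wrapped in DataParallel afterwards if needed.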

Thank you very much for your reply. I changed the code to the following.

if os.path.isfile(CKPT_PATH):
    print("=> loading checkpoint")
    checkpoint = torch.load(CKPT_PATH)
    substring = 'module.'
    checkpoint_tmp = OrderedDict()
    for k in checkpoint:
        new_k = k[len(substring):] if k.startswith(substring) else k
        checkpoint_tmp[new_k] = checkpoint[k]
    checkpoint = checkpoint_tmp
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint")
else:
    print("=> no checkpoint found")

But it’s still the same error. What do you think is the reason?

Seems like you are not the only one: https://github.com/arnoweng/CheXNet/issues/26

From the comments on this issue, I would say the pretrained model given in the repository was exported before PyTorch 0.4.0, and compatibility has since been broken. The architecture itself is probably still valid, but the saved model is not.

Could you paste the complete error message just to be sure?
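While waiting for the full message, the mismatch can also be inspected directly by comparing the two sets of key names. The sketch below is a guess at one possible cause, not a confirmed diagnosis: it assumes the old checkpoint uses DenseNet layer names such as norm.1 where current code expects norm1, and the regular expression is modelled on the renaming that, to my knowledge, torchvision applies to its own older DenseNet weights. The helper names and the sample keys are hypothetical, and plain numbers stand in for tensors:

```python
import re
from collections import OrderedDict

# Assumed old naming: '...denselayerN.norm.1.weight' instead of '...norm1.weight'.
OLD_KEY = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\."
    r"((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)


def diff_keys(model_keys, checkpoint_keys):
    """Return (keys the model wants but lacks, keys the checkpoint has extra)."""
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))
    return missing, unexpected


def rename_old_densenet_keys(state_dict):
    """Return a copy with 'norm.1'-style names rewritten to 'norm1'-style names."""
    renamed = OrderedDict()
    for key, value in state_dict.items():
        match = OLD_KEY.match(key)
        new_key = match.group(1) + match.group(2) if match else key
        renamed[new_key] = value
    return renamed


# Hypothetical key names for illustration; real ones come from
# model.state_dict().keys() and the loaded checkpoint.
prefix = "module.densenet121.features."
model_keys = [prefix + "conv0.weight", prefix + "denseblock1.denselayer1.norm1.weight"]
old_checkpoint = OrderedDict([
    (prefix + "conv0.weight", 0.0),
    (prefix + "denseblock1.denselayer1.norm.1.weight", 0.0),
])

before = diff_keys(model_keys, old_checkpoint)
after = diff_keys(model_keys, rename_old_densenet_keys(old_checkpoint))
print(before)
print(after)  # ([], [])
```

If diff_keys reports a different pattern on the real checkpoint, for example keys that differ only by the module. prefix, this renaming is not the right fix.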