What is args in this case? Is it project-specific?
The code I am using saves the model with torch.save(model). In this case the model is loaded with args.pretrained = torch.load(args.pretrained).
On a single GPU, model is one of my models, MyModelNet(nn.Module), but in the multi-GPU case it is nn.DataParallel(MyModelNet(nn.Module)).
Ok, so that wouldn't really fix the loading problem, but it will help save the correct state_dict() depending on whether your model is parallelized or not.
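For reference, a minimal sketch of that saving logic might look like the following. MyModelNet is the poster's model class and the checkpoint path is a placeholder, not something from the original posts:

import torch
import torch.nn as nn

model = MyModelNet()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

# ... training ...

# Unwrap nn.DataParallel before saving so the state_dict keys carry no
# "module." prefix and the checkpoint loads cleanly on a single GPU later.
to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(to_save.state_dict(), "checkpoint.pth")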
A more graceful solution is:
name = k.replace(".module", "")  # remove ".module" from the key
For me, using k[7:] wasn't properly removing the "module." prefix.
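In context, that one-liner sits inside a loop that rebuilds the state_dict. A rough sketch, assuming model is the plain, un-parallelized module and the checkpoint path is a placeholder:

from collections import OrderedDict

import torch

state_dict = torch.load("checkpoint.pth")  # placeholder path

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k.replace(".module", "")     # handles keys like "synth_reader.module.0.weight"
    if name.startswith("module."):      # handles the usual leading "module." prefix
        name = name[len("module."):]
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)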
I used your code to remove the unexpected keys; however, I cannot get past this error.
I have tried all the other tricks as well.
Please give me some hints to solve it. It would be much appreciated.
state_dict = torch.load("/media/Data/jcl-vb/output_dir/model_0000050.pth")
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[9:]  # strip the first 9 characters (intended to remove "module.")
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
However, it shows the following error:
RuntimeError: Error(s) in loading state_dict for GeneralizedRCNN:
Missing key(s) in state_dict: 'backbone.body.stem.conv1.weight', 'backbone.body.stem.bn1.weight', 'backbone.body.stem.bn1.bias', 'backbone.body.stem.bn1.running_mean', 'backbone.body.stem.bn1.running_var', 'backbone.body.layer1.0.downsample.0.weight', 'backbone.body.layer1.0.downsample.1.weight', 'backbone.body.layer1.0.downsample.1.bias', 'backbone.body.layer1.0.downsample.1.running_mean', 'backbone.body.layer1.0.downsample.1.running_var', 'backbone.body.layer1.0.conv1.weight', 'backbone.body.layer1.0.bn1.weight', 'backbone.body.layer1.0.bn1.bias', 'backbone.body.layer1.0.bn1.running_mean', 'backbone.body.layer1.0.bn1.running_var', 'backbone.body.layer1.0.conv2.weight', 'backbone.body.layer1.0.bn2.weight', 'backbone.body.layer1.0.bn2.bias', 'backbone.body.layer1.0.bn2.running_mean', 'backbone.body.layer1.0.bn2.running_var', 'backbone.body.layer1.0.conv3.weight', 'backbone.body.layer1.0.bn3.weight', 'backbone.body.layer1.0.bn3.bias', 'backbone.body.layer1.0.bn3.running_mean', 'backbone.body.layer1.0.bn3.running_var', 'backbone.body.layer1.1.conv1.weight', 'backbone.body.layer1.1.bn1.weight', 'backbone.body.layer1.1.bn1.bias', 'backbone.body.layer1.1.bn1.running_mean', 'backbone.body.layer1.1.bn1.running_var', 'backbone.body.layer1.1.conv2.weight', 'backbone.body.layer1.1.bn2.weight', 'backbone.body.layer1.1.bn2.bias', 'backbone.body.layer1.1.bn2.running_mean', 'backbone.body.layer1.1.bn2.running_var', 'backbone.body.layer1.1.conv3.weight', 'backbone.body.layer1.1.bn3.weight', 'backbone.body.layer1.1.bn3.bias', 'backbone.body.layer1.1.bn3.running_mean', 'backbone.body.layer1.1.bn3.running_var', 'backbone.body.layer1.2.conv1.weight', 'backbone.body.layer1.2.bn1.weight', 'backbone.body.layer1.2.bn1.bias', 'backbone.body.layer1.2.bn1.running_mean', 'backbone.body.layer1.2.bn1.running_var', 'backbone.body.layer1.2.conv2.weight', 'backbone.body.layer1.2.bn2.weight', 'backbone.body.layer1.2.bn2.bias', 'backbone.body.layer1.2.bn2.running_mean', 'backbone.body.layer1.2.bn2.running_var', 'backbone.body.layer1.2.conv3.weight', 'backbone.body.layer1.2.bn3.weight', 'backbone.body.layer1.2.bn3.bias', 'backbone.body.layer1.2.bn3.running_mean', 'backbone.body.layer1.2.bn3.running_var', 'backbone.body.layer2.0.downsample.0.weight', 'backbone.body.layer2.0.downsample.1.weight', 'backbone.body.layer2.0.downsample.1.bias', 'backbone.body.layer2.0.downsample.1.running_mean', 'backbone.body.layer2.0.downsample.1.running_var', 'backbone.body.layer2.0.conv1.weight', 'backbone.body.layer2.0.bn1.weight', 'backbone.body.layer2.0.bn1.bias', 'backbone.body.layer2.0.bn1.running_mean', 'backbone.body.layer2.0.bn1.running_var', 'backbone.body.layer2.0.conv2.weight', 'backbone.body.layer2.0.bn2.weight', 'backbone.body.layer2.0.bn2.bias', 'backbone.body.layer2.0.bn2.running_mean', 'backbone.body.layer2.0.bn2.running_var', 'backbone.body.layer2.0.conv3.weight', 'backbone.body.layer2.0.bn3.weight', 'backbone.body.layer2.0.bn3.bias', 'backbone.body.layer2.0.bn3.running_mean', 'backbone.body.layer2.0.bn3.running_var', 'backbone.body.layer2.1.conv1.weight', 'backbone.body.layer2.1.bn1.weight', 'backbone.body.layer2.1.bn1.bias', 'backbone.body.layer2.1.bn1.running_mean', 'backbone.body.layer2.1.bn1.running_var', 'backbone.body.layer2.1.conv2.weight', 'backbone.body.layer2.1.bn2.weight', 'backbone.body.layer2.1.bn2.bias', 'backbone.body.layer2.1.bn2.running_mean', 'backbone.body.layer2.1.bn2.running_var', 'backbone.body.layer2.1.conv3.weight', 'backbone.body.layer2.1.bn3.weight', 'backbone.body.layer2.1.bn3.bias', 'backbone.body.layer2.1.bn3.running_mean', 'backbone.body.layer2.1.bn3.running_var', 'backbone.body.layer2.2.conv1.weight', 'backbone.body.layer2.2.bn1.weight', 'backbone.body.layer2.2.bn1.bias', 'backbone.body.layer2.2.bn1.running_mean', 'backbone.body.layer2.2.bn1.running_var', 'backbone.body.layer2.2.conv2.weight', 'backbone.body.layer2.2.bn2.weight', 'backbone.body.layer2.2.bn2.bias', 'backbone.body.layer2.2.bn2.running_mean', 'backbone.body.layer2.2.bn2.running_var', 'backbone.body.layer2.2.conv3.weight', 'backbone.body.layer2.2.bn3.weight', 'backbone.body.layer2.2.bn3.bias', 'backbone.body.layer2.2.bn3.running_mean', 'backbone.body.layer2.2.bn3.running_var', 'backbone.body.layer2.3.conv1.weight', 'backbone.body.layer2.3.bn1.weight', 'backbone.body.layer2.3.bn1.bias', 'backbone.body.layer2.3.bn1.running_mean', 'backbone.body.layer2.3.bn1.running_var', 'backbone.body.layer2.3.conv2.weight', 'backbone.body.layer2.3.bn2.weight', 'backbone.body.layer2.3.bn2.bias', 'backbone.body.layer2.3.bn2.running_mean', 'backbone.body.layer2.3.bn2.running_var', 'backbone.body.layer2.3.conv3.weight', 'backbone.body.layer2.3.bn3.weight', 'backbone.body.layer2.3.bn3.bias', 'backbone.body.layer2.3.bn3.running_mean', 'backbone.body.layer2.3.bn3.running_var', 'backbone.body.layer3.0.downsample.0.weight', 'backbone.body.layer3.0.downsample.1.weight', 'backbone.body.layer3.0.downsample.1.bias', 'backbone.body.layer3.0.downsample.1.running_mean', 'backbone.body.layer3.0.downsample.1.running_var', 'backbone.body.layer3.0.conv1.weight', 'backbone.body.layer3.0.bn1.weight', 'backbone.body.layer3.0.bn1.bias', 'backbone.body.layer3.0.bn1.running_mean', 'backbone.body.layer3.0.bn1.running_var', 'backbone.body.layer3.0.conv2.weight', 'backbone.body.layer3.0.bn2.weight', 'backbone.body.layer3.0.bn2.bias', 'backbone.body.layer3.0.bn2.running_mean', 'backbone.body.layer3.0.bn2.running_var', 'backbone.body.layer3.0.conv3.weight', 'backbone.body.layer3.0.bn3.weight', 'backbone.body.layer3.0.bn3.bias', 'backbone.body.layer3.0.bn3.running_mean', 'backbone.body.layer3.0.bn3.running_var', 'backbone.body.layer3.1.conv1.weight', 'backbone.body.layer3.1.bn1.weight', 'backbone.body.layer3.1.bn1.bias', 'backbone.body.layer3.1.bn1.running_mean', 'backbone.body.layer3.1.bn1.running_var', 'backbone.body.layer3.1.conv2.weight', 'backbone.body.layer3.1.bn2.weight', 'backbone.body.layer3.1.bn2.bias', 'backbone.body.layer3.1.bn2.running_mean', 'backbone.body.layer3.1.bn2.running_var', 'backbone.body.layer3.1.conv3.weight', 'backbone.body.layer3.1.bn3.weight', 'backbone.body.layer3.1.bn3.bias', 'backbone.body.layer3.1.bn3.running_mean', 'backbone.body.layer3.1.bn3.running_var', 'backbone.body.layer3.2.conv1.weight', 'backbone.body.layer3.2.bn1.weight', 'backbone.body.layer3.2.bn1.bias', 'backbone.body.layer3.2.bn1.running_mean', 'backbone.body.layer3.2.bn1.running_var', 'backbone.body.layer3.2.conv2.weight', 'backbone.body.layer3.2.bn2.weight', 'backbone.body.layer3.2.bn2.bias', 'backbone.body.layer3.2.bn2.running_mean', 'backbone.body.layer3.2.bn2.running_var', 'backbone.body.layer3.2.conv3.weight', 'backbone.body.layer3.2.bn3.weight', 'backbone.body.layer3.2.bn3.bias', 'backbone.body.layer3.2.bn3.running_mean', 'backbone.body.layer3.2.bn3.running_var', 'backbone.body.layer3.3.conv1.weight', 'backbone.body.layer3.3.bn1.weight', 'backbone.body.layer3.3.bn1.bias', 'backbone.body.layer3.3.bn1.running_mean', 'backbone.body.layer3.3.bn1.running_var', 'backbone.body.layer3.3.conv2.weight', 'backbone.body.layer3.3.bn2.weight', 'backbone.body.layer3.3.bn2.bias', 'backbone.body.layer3.3.bn2.running_mean', 'backbone.body.layer3.3.bn2.running_var', 'backbone.body.layer3.3.conv3.weight', 'backbone.body.layer3.3.bn3.weight', 'backbone.body.layer3.3.bn3.bias', 'backbone.body.layer3.3.bn3.running_mean', 'backbone.body.layer3.3.bn3.running_var', 'backbone.body.layer3.4.conv1.weight', 'backbone.body.layer3.4.bn1.weight', 'backbone.body.layer3.4.bn1.bias', 'backbone.body.layer3.4.bn1.running_mean', 'backbone.body.layer3.4.bn1.running_var', 'backbone.body.layer3.4.conv2.weight', 'backbone.body.layer3.4.bn2.weight', 'backbone.body.layer3.4.bn2.bias', 'backbone.body.layer3.4.bn2.running_mean', 'backbone.body.layer3.4.bn2.running_var', 'backbone.body.layer3.4.conv3.weight', 'backbone.body.layer3.4.bn3.weight', 'backbone.body.layer3.4.bn3.bias', 'backbone.body.layer3.4.bn3.running_mean', 'backbone.body.layer3.4.bn3.running_var', 'backbone.body.layer3.5.conv1.weight', 'backbone.body.layer3.5.bn1.weight', 'backbone.body.layer3.5.bn1.bias', 'backbone.body.layer3.5.bn1.running_mean', 'backbone.body.layer3.5.bn1.running_var', 'backbone.body.layer3.5.conv2.weight', 'backbone.body.layer3.5.bn2.weight', 'backbone.body.layer3.5.bn2.bias', 'backbone.body.layer3.5.bn2.running_mean', 'backbone.body.layer3.5.bn2.running_var', 'backbone.body.layer3.5.conv3.weight', 'backbone.body.layer3.5.bn3.weight', 'backbone.body.layer3.5.bn3.bias', 'backbone.body.layer3.5.bn3.running_mean', 'backbone.body.layer3.5.bn3.running_var', 'backbone.body.layer4.0.downsample.0.weight', 'backbone.body.layer4.0.downsample.1.weight', 'backbone.body.layer4.0.downsample.1.bias', 'backbone.body.layer4.0.downsample.1.running_mean', 'backbone.body.layer4.0.downsample.1.running_var', 'backbone.body.layer4.0.conv1.weight', 'backbone.body.layer4.0.bn1.weight', 'backbone.body.layer4.0.bn1.bias', 'backbone.body.layer4.0.bn1.running_mean', 'backbone.body.layer4.0.bn1.running_var', 'backbone.body.layer4.0.conv2.weight', 'backbone.body.layer4.0.bn2.weight', 'backbone.body.layer4.0.bn2.bias', 'backbone.body.layer4.0.bn2.running_mean', 'backbone.body.layer4.0.bn2.running_var', 'backbone.body.layer4.0.conv3.weight', 'backbone.body.layer4.0.bn3.weight', 'backbone.body.layer4.0.bn3.bias', 'backbone.body.layer4.0.bn3.running_mean', 'backbone.body.layer4.0.bn3.running_var', 'backbone.body.layer4.1.conv1.weight', 'backbone.body.layer4.1.bn1.weight', 'backbone.body.layer4.1.bn1.bias', 'backbone.body.layer4.1.bn1.running_mean', 'backbone.body.layer4.1.bn1.running_var', 'backbone.body.layer4.1.conv2.weight', 'backbone.body.layer4.1.bn2.weight', 'backbone.body.layer4.1.bn2.bias', 'backbone.body.layer4.1.bn2.running_mean', 'backbone.body.layer4.1.bn2.running_var', 'backbone.body.layer4.1.conv3.weight', 'backbone.body.layer4.1.bn3.weight', 'backbone.body.layer4.1.bn3.bias', 'backbone.body.layer4.1.bn3.running_mean', 'backbone.body.layer4.1.bn3.running_var', 'backbone.body.layer4.2.conv1.weight', 'backbone.body.layer4.2.bn1.weight', 'backbone.body.layer4.2.bn1.bias', 'backbone.body.layer4.2.bn1.running_mean', 'backbone.body.layer4.2.bn1.running_var', 'backbone.body.layer4.2.conv2.weight', 'backbone.body.layer4.2.bn2.weight', 'backbone.body.layer4.2.bn2.bias', 'backbone.body.layer4.2.bn2.running_mean', 'backbone.body.layer4.2.bn2.running_var', 'backbone.body.layer4.2.conv3.weight', 'backbone.body.layer4.2.bn3.weight', 'backbone.body.layer4.2.bn3.bias', 'backbone.body.layer4.2.bn3.running_mean', 'backbone.body.layer4.2.bn3.running_var', 'backbone.fpn.fpn_inner1.weight', 'backbone.fpn.fpn_inner1.bias', 'backbone.fpn.fpn_layer1.weight', 'backbone.fpn.fpn_layer1.bias', 'backbone.fpn.fpn_inner2.weight', 'backbone.fpn.fpn_inner2.bias', 'backbone.fpn.fpn_layer2.weight', 'backbone.fpn.fpn_layer2.bias', 'backbone.fpn.fpn_inner3.weight', 'backbone.fpn.fpn_inner3.bias', 'backbone.fpn.fpn_layer3.weight', 'backbone.fpn.fpn_layer3.bias', 'backbone.fpn.fpn_inner4.weight', 'backbone.fpn.fpn_inner4.bias', 'backbone.fpn.fpn_layer4.weight', 'backbone.fpn.fpn_layer4.bias', 'rpn.anchor_generator.cell_anchors.0', 'rpn.anchor_generator.cell_anchors.1', 'rpn.anchor_generator.cell_anchors.2', 'rpn.anchor_generator.cell_anchors.3', 'rpn.anchor_generator.cell_anchors.4', 'rpn.head.conv.weight', 'rpn.head.conv.bias', 'rpn.head.cls_logits.weight', 'rpn.head.cls_logits.bias', 'rpn.head.bbox_pred.weight', 'rpn.head.bbox_pred.bias', 'roi_heads.box.feature_extractor.fc6.weight', 'roi_heads.box.feature_extractor.fc6.bias', 'roi_heads.box.feature_extractor.fc7.weight', 'roi_heads.box.feature_extractor.fc7.bias', 'roi_heads.box.predictor.cls_score.weight', 'roi_heads.box.predictor.cls_score.bias', 'roi_heads.box.predictor.bbox_pred.weight', 'roi_heads.box.predictor.bbox_pred.bias'.
Unexpected key(s) in state_dict: ''.
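One thing worth checking here: k[9:] drops nine characters, while the "module." prefix is only seven, and the keys in the checkpoint may not carry the prefix at all. A sketch of a prefix-aware strip plus a quick inspection of the keys (model stands for the GeneralizedRCNN instance from the post above):

from collections import OrderedDict

import torch

state_dict = torch.load("/media/Data/jcl-vb/output_dir/model_0000050.pth")
print(list(state_dict.keys())[:5])   # see what the keys actually look like first

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Strip the prefix only when it is really present; slicing a fixed number
    # of characters can eat into the real parameter name.
    name = k[len("module."):] if k.startswith("module.") else k
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)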
Yup, same here. In my case it's e.g. 'synth_reader.0.weight' and 'synth_reader.module.0.weight', so the replace works like a charm.
Can you explain more clearly how to add nn.DataParallel
temporarily to your network for loading purposes? E.g., can you provide a simple example?
I am new to PyTorch, thanks so much!
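A minimal sketch of that idea, assuming the checkpoint holds a state_dict saved from a model wrapped in nn.DataParallel; MyModelNet and the file name are placeholders:

import torch
import torch.nn as nn

model = MyModelNet()                       # plain, un-parallelized model

# Wrap it in nn.DataParallel only so its keys gain the same "module." prefix
# as the saved state_dict, load, then unwrap again.
wrapped = nn.DataParallel(model)
wrapped.load_state_dict(torch.load("checkpoint.pth"))  # placeholder path
model = wrapped.module                     # underlying model with loaded weights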
You can use strict=False in load_state_dict. This can solve the issue.
model.load_state_dict(checkpoint['state_dict'], strict=False)
I would recommend caution when using strict=False here. I tried this, and replacing "module." worked correctly and gave me the validation results expected from my model, while strict=False gave wrong results (even though PyTorch did not complain).
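If you do use strict=False, it helps to inspect what was skipped; in recent PyTorch versions load_state_dict returns the unmatched keys (model and checkpoint as in the snippet above):

# load_state_dict with strict=False reports what it could not match.
result = model.load_state_dict(checkpoint['state_dict'], strict=False)
print("Missing keys:   ", result.missing_keys)      # parameters left at their initial values
print("Unexpected keys:", result.unexpected_keys)   # checkpoint entries that were ignored

# If either list contains real parameters (not just buffers you meant to skip),
# the "successfully" loaded model will quietly give wrong results.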
It worked! Huuuuge thanks.
As @arturs.polis also mentioned, be careful with this! For me, it removed the error message, but gave a completely wrong model without warning!
Simple, but it worked perfectly.
@arturs.polis did you find a way to correct the results when using strict=False?
It can solve this but also cause other problems. In my case, as @arturs.polis already pointed out, the model returned completely wrong results.
The solution provided by @fmassa worked fine for me.
Hello @fmassa, I am using your approach to change some of the key names for the model, but I am getting an error while loading it. Following are my code and the error message:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self, num_classes=43, input_channels=3):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
        if 1 == num_classes:
            # compatible with nn.BCELoss
            self.softmax = nn.Sigmoid()
        else:
            # compatible with nn.CrossEntropyLoss
            self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        out = self.softmax(out)
        return out
device = "cuda" if torch.cuda.is_available() else "cpu"
teacher_model = LeNet()
checkpoint = torch.load('model_best.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if k == 'conv1.weight':
        name = 'features.0.weight'
        new_state_dict[name] = v
    elif k == 'conv1.bias':
        name = 'features.0.bias'
        new_state_dict[name] = v

checkpoint['state_dict'] = new_state_dict
teacher_model.load_state_dict(checkpoint['state_dict'])
I am changing all the names, but I have shown only two here to save space. It gives me the following error:
Error(s) in loading state_dict for LeNet:
Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "fc3.weight", "fc3.bias".
Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.3.weight", "features.3.bias", "classifier.0.weight", "classifier.0.bias", "classifier.2.weight", "classifier.2.bias", "classifier.4.weight", "classifier.4.bias".
Please help me solve the problem. Thank you.
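Judging from the error, the checkpoint already uses features.* / classifier.* names while LeNet expects conv1 / fc1-style names, so the renaming may need to go in the other direction (checkpoint name to LeNet name). A sketch; the index-to-layer correspondence below is an assumption about how the original model was built as an nn.Sequential, so verify it against the model that produced the checkpoint:

from collections import OrderedDict

# Assumed mapping between Sequential-style checkpoint prefixes and LeNet's
# attribute names -- check this against the saved architecture before trusting
# the loaded weights.
key_map = {
    'features.0': 'conv1',
    'features.3': 'conv2',
    'classifier.0': 'fc1',
    'classifier.2': 'fc2',
    'classifier.4': 'fc3',
}

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    prefix, _, suffix = k.rpartition('.')              # e.g. 'features.0', '.', 'weight'
    new_state_dict[key_map.get(prefix, prefix) + '.' + suffix] = v

teacher_model.load_state_dict(new_state_dict)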
Double post from here.
This is Gorgeous!!!
I solved my problem.
You saved my life!!
Thanks a lot from South Korea
As previously mentioned, you will end up with a partially loaded state_dict, and you should avoid that!!!
Where did you add DataParallel?
It works for me, thanks! Love from China.