I trained a model and saved its state_dict, which I am now trying to load again. However, I am getting a “missing keys in state_dict” error while loading, even though the model has the exact same architecture.
The “missing keys” and “unexpected keys” are also very strange. It looks like the only difference between them is that the missing keys have an extra ‘0’ in them, as follows:
missing key: “feature_extractor_trainable.0.0.conv1.weight”
unexpected key: “feature_extractor_trainable.0.conv1.weight”
All missing keys have the same issue: there is an extra 0 in the parameter name.
Any idea what might be causing this?
Thanks.
IliasPap (Ilias Pap), March 10, 2020, 10:26am
That extra 0 is probably coming from an nn.Sequential module.
Code showing your model and your saving function would be useful.
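For reference, here is a minimal sketch (not your code) of how one extra level of nn.Sequential nesting prepends another index to every state_dict key:

import torch.nn as nn

# A single conv layer wrapped once vs. twice in nn.Sequential
inner = nn.Sequential(nn.Conv2d(3, 8, 3))
outer = nn.Sequential(inner)  # one extra level of nesting

print(list(inner.state_dict().keys()))  # ['0.weight', '0.bias']
print(list(outer.state_dict().keys()))  # ['0.0.weight', '0.0.bias']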
This is the model architecture:
import torch
import torch.nn as nn
import torchvision


class Net(nn.Module):
    def __init__(self, n_classes=4, l1=5, l2=6):
        super(Net, self).__init__()
        resnet = torchvision.models.resnet152(pretrained=True)
        num_features = resnet.fc.in_features
        # Frozen part of the backbone
        feature_extractor = nn.Sequential(*list(resnet.children())[:l1])
        # Trainable part of the backbone
        feature_extractor_trainable = nn.Sequential(*list(resnet.children())[l1:l2])
        for param in feature_extractor.parameters():
            param.requires_grad = False
        self.feature_extractor = feature_extractor
        self.feature_extractor_trainable = feature_extractor_trainable
        self.conv = nn.Sequential(*list(resnet.children())[l2:-1])
        self.classifier = nn.Linear(num_features, n_classes)

    def forward(self, inputs):
        # Split the 6-channel input into pre- and post-image halves
        pre_img_features = self.feature_extractor(inputs[:, :3, :, :])
        post_img_features = self.feature_extractor(inputs[:, 3:, :, :])
        pre_img_features = self.feature_extractor_trainable(pre_img_features)
        post_img_features = self.feature_extractor_trainable(post_img_features)
        out = post_img_features - pre_img_features
        del pre_img_features, post_img_features
        out = self.conv(out)
        out = torch.squeeze(out)
        out = self.classifier(out)
        return out
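A quick way to see exactly which keys differ (a rough sketch; “checkpoint.pth” is just a placeholder path, assuming the file was saved with torch.save(model.state_dict(), ...)) is to compare the two key sets directly:

import torch

model = Net()
state = torch.load("checkpoint.pth")  # placeholder path for the saved state_dict

model_keys = set(model.state_dict().keys())
saved_keys = set(state.keys())
print("missing from checkpoint:", sorted(model_keys - saved_keys)[:5])
print("unexpected in checkpoint:", sorted(saved_keys - model_keys)[:5])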
IliasPap (Ilias Pap), March 10, 2020, 1:45pm
Your model’s state_dict is not saving the feature_extractor module because you set param.requires_grad = False.
That’s not the case: the state_dict will include all registered parameters and buffers needed to restore the model.
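A minimal sketch showing this — freezing parameters does not remove them from the state_dict:

import torch.nn as nn

layer = nn.Linear(2, 2)
for param in layer.parameters():
    param.requires_grad = False  # freeze, just like the feature_extractor above

# Frozen parameters are still registered and still appear in the state_dict
print(list(layer.state_dict().keys()))  # ['weight', 'bias']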
@Khubaib_Siddiqui
Your model works fine using this code:
model = Net()
sd = model.state_dict()

# load the saved state_dict into a fresh instance of the same model
model = Net()
model.load_state_dict(sd)
Did you change anything regarding the nn.Sequential usage, as @IliasPap suggested?
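If the nn.Sequential nesting really did change between saving and loading, one possible workaround (a rough sketch, assuming the only difference is the extra ‘0’ right after the feature_extractor_trainable prefix and that the underlying layers are otherwise identical) is to remap the checkpoint keys before loading:

import torch

state = torch.load("checkpoint.pth")  # placeholder path

# Insert the missing index after the prefix so the saved keys match the
# current model's extra nn.Sequential nesting level
prefix = "feature_extractor_trainable."
remapped = {
    (prefix + "0." + key[len(prefix):]) if key.startswith(prefix) else key: value
    for key, value in state.items()
}

model = Net()
model.load_state_dict(remapped)

Any other submodule whose keys show the same extra index would need the same treatment.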