I trained a model and saved the state_dict that I am trying to load again. However, I am getting a “missing keys in state_dict” while loading the model again, even though it has the exact same architecture.
The “missing keys” and “unexpected keys” are also very strange. It looks like the only difference between missing keys and unexpected keys is that the missing keys have an extra ‘0’ in them, as follows:
missing key: “feature_extractor_trainable.0.0.conv1.weight”
unexpected key: “feature_extractor_trainable.0.conv1.weight”
All missing keys have the same issue: there is an extra 0 in the parameter name.
Any idea of what might be causing this?
That extra 0 is probably coming from an nn.Sequential module.
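For illustration, here is a minimal example (not from the original post) showing how each level of nn.Sequential nesting prepends a numeric index to the state_dict keys, which would produce exactly the extra "0" you are seeing:

```python
import torch.nn as nn

# A single nn.Sequential indexes its children as '0', '1', ...
inner = nn.Sequential(nn.Conv2d(3, 8, 3))
print(list(inner.state_dict().keys()))   # ['0.weight', '0.bias']

# Wrapping it in another nn.Sequential adds one more index level:
outer = nn.Sequential(inner)
print(list(outer.state_dict().keys()))   # ['0.0.weight', '0.0.bias']
```

So a checkpoint saved from `outer` cannot be loaded into `inner` (or vice versa) without a key mismatch, even though the underlying layers are identical.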
Posting the code for your model and your saving function would be useful.
This is the model architecture:
import torch
import torch.nn as nn
import torchvision

class Net(nn.Module):
    def __init__(self, n_classes=4, l1=5, l2=6):
        super().__init__()
        resnet = torchvision.models.resnet152(pretrained=True)
        num_features = resnet.fc.in_features
        feature_extractor = nn.Sequential(*list(resnet.children())[:l1])
        feature_extractor_trainable = nn.Sequential(*list(resnet.children())[l1:l2])
        # Freeze the early layers
        for param in feature_extractor.parameters():
            param.requires_grad = False
        self.feature_extractor = feature_extractor
        self.feature_extractor_trainable = feature_extractor_trainable
        self.conv = nn.Sequential(*list(resnet.children())[l2:-1])
        self.classifier = nn.Linear(num_features, n_classes)

    def forward(self, inputs):
        # First 3 channels: pre image; remaining channels: post image
        pre_img_features = self.feature_extractor(inputs[:, :3, :, :])
        post_img_features = self.feature_extractor(inputs[:, 3:, :, :])
        pre_img_features = self.feature_extractor_trainable(pre_img_features)
        post_img_features = self.feature_extractor_trainable(post_img_features)
        out = post_img_features - pre_img_features
        del pre_img_features, post_img_features
        out = self.conv(out)
        out = torch.squeeze(out)
        out = self.classifier(out)
        return out
Your model's state_dict is not saving the feature_extractor module because you set requires_grad = False on its parameters.
That’s not the case: the state_dict includes all registered parameters and buffers needed to restore the model, regardless of requires_grad.
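A quick check (a standalone example, not from the original post) confirms that frozen parameters still show up in the state_dict:

```python
import torch.nn as nn

layer = nn.Linear(4, 2)
for p in layer.parameters():
    p.requires_grad = False

# Frozen parameters remain registered and are still serialized:
print(sorted(layer.state_dict().keys()))  # ['bias', 'weight']
```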
Your model works fine using this code:
model = Net()
sd = model.state_dict()
model = Net()
model.load_state_dict(sd)
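If the checkpoint really was saved from a version of the model with one less nesting level, one possible workaround (a sketch only; the prefix below is taken from the keys quoted in the question) is to remap the checkpoint keys before calling load_state_dict:

```python
# Hypothetical remapping: insert the missing '0.' index into each
# affected checkpoint key so it matches the current model's names.
ckpt = {"feature_extractor_trainable.0.conv1.weight": None}  # illustrative
prefix = "feature_extractor_trainable."
fixed = {
    (prefix + "0." + k[len(prefix):]) if k.startswith(prefix) else k: v
    for k, v in ckpt.items()
}
print(list(fixed.keys()))  # ['feature_extractor_trainable.0.0.conv1.weight']
```

After remapping, `model.load_state_dict(fixed)` should no longer report those keys as missing. But the cleaner fix is to find where the module structure diverged between saving and loading.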
Did you change anything regarding the nn.Sequential usage, as @IliasPap suggested?