Hello @ptrblck and @kaltu , I am facing similar problem. I want to change the names of two keys in the model. I tried the suggested method, but I am getting errors. My code snippet and error message are below:
class LeNet(nn.Module):
def __init__(self, num_classes=43, input_channels=3):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(input_channels, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, num_classes)
if 1 == num_classes:
# compatible with nn.BCELoss
self.softmax = nn.Sigmoid()
else:
# compatible with nn.CrossEntropyLoss
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
out = self.softmax(out)
return out
teacher_model = LeNet() # get the model
checkpoint = torch.load('model_best.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']
for key in list(state_dict.keys()):
state_dict[key.replace("conv1.weight","features.0.weight"). replace("conv1.bias", "features.0.bias")] = state_dict.pop(key)
teacher_model.load_state_dict(checkpoint['state_dict'])
It is giving me the following error:
Error(s) in loading state_dict for LeNet:
Missing key(s) in state_dict: "conv1.weight", "conv1.bias".
Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias".