How to load a pretrained model with a different output dimension

Hello,

I would like to ask a basic question.

I save my training models using

torch.save(model.state_dict(), save_path)

and load them using

model.load_state_dict(torch.load(save_path))

My training data can get updated over time, so to train a new model for the new data, I use a previous model as pretrained weights. The problem is that if the new data has more (or fewer) classes, I get an error:

size mismatch for fc.bias: copying a param of torch.Size([45]) from checkpoint, where the shape is torch.Size([44]) in current model.

My question is: How to load the pretrained model without the fully connected layer so that the above error does not occur?

Thank you very much for your help!


It’s from a different discussion, but you could try to filter out only the parameters from your pretrained state_dict which are present in your current model: thread.


@ptrblck Thanks, but it seems to me that your method cannot be applied to my case. For convenience let me quote it here:

import torch.nn as nn
from collections import OrderedDict


def weight_init(m):
    # Re-initialize linear layers (used for parameters not loaded from the checkpoint)
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0.)


pretrained_model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(10, 20)),
    ('relu1', nn.ReLU()),
    ('fc2', nn.Linear(20, 2))
]))
pretrained_dict = pretrained_model.state_dict()

model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(10, 20)),
    ('relu1', nn.ReLU()),
    ('fc2', nn.Linear(20, 2)),
    ('relu2', nn.ReLU()),
    ('fc3', nn.Linear(2, 2))
]))
# Initialize model
model.apply(weight_init)
model_dict = model.state_dict()

# Filter out unnecessary keys
filtered_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)

If I understand it correctly, this code loads, from a pretrained model, the weights of the layers whose names also appear in the main model. In my case, the layer names in both models may be exactly the same; only the dimensions of the last fully connected layers differ.
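
For reference, a shape-aware variant of that filter (a sketch, not taken from the linked thread) would skip the mismatched fc layer, whose weights then simply keep their fresh initialization:

import torch

# 'model' is the current model with the new number of classes (defined elsewhere)
pretrained_dict = torch.load('state_dict.pth')
model_dict = model.state_dict()

# Keep only entries whose name AND shape match the current model
filtered_dict = {k: v for k, v in pretrained_dict.items()
                 if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)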

By the way, your answer there is very helpful for a very popular use case, so thanks for that as well!


Here is one not-so-elegant but gets-things-done approach: why not just load the original model and replace the last FC layer?

Here is an example using the pretrained AlexNet in PyTorch.

import torch.nn as nn
from torchvision import models

model = models.alexnet(pretrained=True)
# Replace the final classifier layer with one matching the new number of classes
classifier = list(model.classifier.children())
model.classifier = nn.Sequential(*classifier[:-1])
model.classifier.add_module(
    '6', nn.Linear(classifier[-1].in_features, num_classes))

Hope it helps.

Thanks. That seems to be a straightforward idea. I will have to save the whole model and not only the state dict, though (otherwise I cannot load the pretrained model).

It would be nice to be able to load the state dict and bypass the layers whose dimensions do not match. A solution is to change this part of the code:

                if input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append('size mismatch for {}: copying a param of {} from checkpoint, '
                                      'where the shape is {} in current model.'
                                      .format(key, input_param.shape, param.shape))
                    continue

to

import warnings
...
                if input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    message = ('size mismatch for {}: copying a param of {} from checkpoint, '
                               'where the shape is {} in current model.'
                               .format(key, input_param.shape, param.shape))
                    if strict:
                        error_msgs.append(message)
                    else:
                        warnings.warn(message)
                    continue
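
With a change like that, loading with strict=False would only warn about the mismatched fc layer instead of raising. A sketch of the intended usage, assuming the patch above is applied and reusing the new_num_classes = 100 example from below:

import torch
import torch.nn as nn
import torchvision

new_num_classes = 100  # example value

model = torchvision.models.resnet152()
model.fc = nn.Linear(model.fc.in_features, new_num_classes)

# With the patched load_state_dict, the mismatched fc.* entries only produce
# warnings when strict=False, and the new fc layer keeps its initialization
model.load_state_dict(torch.load('state_dict.pth'), strict=False)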

I’m not sure what you mean by saving the whole model. You can still save using the state dict; you just need the original model definition code, which is also needed to load the state. For the case in my example:

# Original AlexNet
class AlexNet(nn.Module):
    ...

# Load the saved state dict (saved earlier with torch.save(model.state_dict(), save_path))
state = torch.load(save_path)
model = AlexNet()
model.load_state_dict(state)

# Replace the final FC layer
classifier = list(model.classifier.children())
model.classifier = nn.Sequential(*classifier[:-1])
model.classifier.add_module(
    '6', nn.Linear(classifier[-1].in_features, num_classes))

For now, this is the only way without manually changing the PyTorch source code. I hope this clarifies my previous answer.

Sorry I was not clear enough in the previous post.

I meant that if only a state dict is given and we don’t know the dimension of the output (i.e. the number of classes), then it is not possible to load the state dict.

Let me give an example (with ResNet, because that’s what I’m using).

Suppose that the dataset has changed and we now have new_num_classes = 100 classes, but we don't know the previous old_num_classes. All the information we have is state_dict.pth. To define a model for training, I use the following code.

# Create the model and change the dimension of the output
model = torchvision.models.resnet152()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, new_num_classes)
# Load the pre-trained weights, only works if the dimensions are the same
model.load_state_dict(torch.load('state_dict.pth'))

If old_num_classes != new_num_classes then the last line will cause an error.

In conclusion, I have to either save the whole model using torch.save(the_model, PATH), or save only the state dict and the number of classes.
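
A small sketch of that second option (the file and key names are illustrative, not from the thread): save the number of classes alongside the state dict, then rebuild the model from it when loading.

import torch
import torch.nn as nn
import torchvision

# Saving: store the class count together with the weights
torch.save({'state_dict': model.state_dict(),
            'num_classes': new_num_classes},
           'checkpoint.pth')

# Loading: rebuild the fc layer from the stored class count, no guessing needed
checkpoint = torch.load('checkpoint.pth')
model = torchvision.models.resnet152()
model.fc = nn.Linear(model.fc.in_features, checkpoint['num_classes'])
model.load_state_dict(checkpoint['state_dict'])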


I have the same error as in the example you gave. I only changed the number of classes (new_num_classes = 100), and now I want to load the model after fine-tuning, but I get an error. I don't know how to fix it. Have you solved it?

@xiao You need to know the old number of classes, then you can do this:

# Create the model and change the dimension of the output
model = torchvision.models.resnet152()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, old_num_classes)
# Load the pre-trained model, which has old_num_classes
model.load_state_dict(torch.load('state_dict.pth'))
# Now change the model to new_num_classes
model.fc = nn.Linear(num_ftrs, new_num_classes)
# Done

Thank you very much. Maybe I didn't express it clearly before: I fine-tuned the pretrained model, then changed the number of classes of the FC layer, then saved the model. Now I want to load it, and there is a problem.

What does your code look like and what is the error?

Thanks. I fine-tuned the pretrained model, then saved it after training on my own data. The code for fine-tuning the pretrained model and saving it is as follows:

# Finetuning the convnet
model_ft = models.resnet101(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 11)

# Save the fine-tuned model
torch.save({'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()},
           './checkpoint/' + 'new.pth.tar')

But when I load the saved model, like this:

checkpoint = torch.load("./checkpoint/new.pth.tar")
myresnet = model_ft.load_state_dict(checkpoint['state_dict'])

there is an error. I don't know how to define the model, since I use the pretrained model. The error is:

myresnet = model_ft.load_state_dict(checkpoint['state_dict'])
NameError: name 'model_ft' is not defined

I think the problem may have happened like this:

  • The first time, the checkpoint you load has the old number of classes, so you need to build the fc layer to match the checkpoint, load it, then replace the fc layer with your desired num_classes and save a new checkpoint.
  • The second time, you want to reload the new checkpoint created during the first training run. The problem is that its num_classes is already your desired num_classes, so if you repeat the same process as in the first run, it causes an error.

Hope it helps ^^
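
One way to avoid guessing either way (a minimal sketch, not from the posts above): read the saved number of classes from the checkpoint itself before building the model, so the same loading code works whether the saved fc layer has the old or the new class count.

import torch
import torch.nn as nn
import torchvision

checkpoint = torch.load('./checkpoint/new.pth.tar')
state_dict = checkpoint['state_dict']

# For torchvision ResNets, fc.weight has shape [num_classes, in_features]
saved_num_classes = state_dict['fc.weight'].shape[0]

# Rebuild the model to match the checkpoint, then load it
model = torchvision.models.resnet101()
model.fc = nn.Linear(model.fc.in_features, saved_num_classes)
model.load_state_dict(state_dict)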