Transfer Learning: Missing key(s) in state_dict & Unexpected key(s) in state_dict

Hi everyone,

I have seen similar topics to this one, but I haven’t been able to work out a solution from the posts on this board. Thanks in advance for your patience!

I’m training a model on an AWS virtual instance:
Deep Learning AMI (Ubuntu) Version 20.0 (ami-0d0ff0945ae093aea)

I know that it is GPU-enabled, so I can train on CUDA:

import torch

# check if CUDA is available
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Making sure it’s on CUDA before training:

from torch.backends import cudnn

if train_on_gpu:
    model.cuda()            # move the model parameters to the GPU
    cudnn.benchmark = True  # let cuDNN pick the fastest conv algorithms

Saving the model:

torch.save({'arch': 'vgg19',
            'state_dict': model.state_dict(),
            'class_to_idx': model.class_to_idx},
            'classifier.pth')

I made a function that loads the model:

from collections import OrderedDict

from torch import nn
from torchvision import models

def load_model(checkpoint_path):
    chpt = torch.load(checkpoint_path)

    if chpt['arch'] == 'vgg19':
        model = models.vgg19(pretrained=True)
        # Freeze the pretrained feature extractor
        for param in model.parameters():
            param.requires_grad = False

    model.class_to_idx = chpt['class_to_idx']

    model.cpu()

    # Create the classifier
    classifier = nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(25088, 4096)),
        ('relu', nn.ReLU()),
        ('fc2', nn.Linear(4096, 102)),
        ('output', nn.LogSoftmax(dim=1))
    ]))
    # Put the classifier on the pretrained network
    model.classifier = classifier

    model.load_state_dict(chpt['state_dict'])

    return model

Loading the model:

model = load_model('vgg19.pth')

The error below is perplexing: the missing keys and the unexpected keys are identical except for a module. prefix on the unexpected ones.
How can this be?

RuntimeError: Error(s) in loading state_dict for VGG:
	Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.16.weight", "features.16.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.23.weight", "features.23.bias", "features.25.weight", "features.25.bias", "features.28.weight", "features.28.bias", "features.30.weight", "features.30.bias", "features.32.weight", "features.32.bias", "features.34.weight", "features.34.bias", "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias". 
	Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.2.weight", "module.features.2.bias", "module.features.5.weight", "module.features.5.bias", "module.features.7.weight", "module.features.7.bias", "module.features.10.weight", "module.features.10.bias", "module.features.12.weight", "module.features.12.bias", "module.features.14.weight", "module.features.14.bias", "module.features.16.weight", "module.features.16.bias", "module.features.19.weight", "module.features.19.bias", "module.features.21.weight", "module.features.21.bias", "module.features.23.weight", "module.features.23.bias", "module.features.25.weight", "module.features.25.bias", "module.features.28.weight", "module.features.28.bias", "module.features.30.weight", "module.features.30.bias", "module.features.32.weight", "module.features.32.bias", "module.features.34.weight", "module.features.34.bias", "module.classifier.fc1.weight", "module.classifier.fc1.bias", "module.classifier.fc2.weight", "module.classifier.fc2.bias".

Your issue is probably related to this.
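
For context, nn.DataParallel registers the wrapped network as self.module, so every key in its state_dict gains a module. prefix. A minimal sketch of the effect (on a toy nn.Linear, not your VGG):

import torch.nn as nn

net = nn.Linear(2, 2)
print(list(net.state_dict().keys()))
# ['weight', 'bias']
print(list(nn.DataParallel(net).state_dict().keys()))
# ['module.weight', 'module.bias']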

Hi @ptrblck thanks for the quick response, appreciate it!
I had to modify the code as shown below.
I know that the classifier of vgg19 needs to be modified to fit this particular case.
My follow-up question is: does setting strict=False for load_state_dict distort the results?

def load_model(checkpoint_path):
    chpt = torch.load(checkpoint_path)

    # Load the model
    if chpt['arch'] == 'vgg19':
        model = models.vgg19(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False

    # Load the state dict, stripping the 'module.' prefix added by nn.DataParallel
    state_dict = chpt['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove 'module.' of DataParallel
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict, strict=False)

    # Load the classes
    model.class_to_idx = chpt['class_to_idx']

    # Create the classifier and put it on the pretrained network
    # (note that this replaces the classifier after the state dict was loaded)
    classifier = nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(25088, 4096)),
        ('relu', nn.ReLU()),
        ('fc2', nn.Linear(4096, 102)),
        ('output', nn.LogSoftmax(dim=1))
    ]))
    model.classifier = classifier

    return model

strict=False doesn’t enforce that the keys from your state_dict match all keys in your model’s state_dict. What do you mean by “distort the results”?
Do you see any weird results?
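
If you want to see exactly what was skipped, recent PyTorch versions return the mismatches from load_state_dict (older versions may not; a sketch):

result = model.load_state_dict(new_state_dict, strict=False)
print('missing keys:   ', result.missing_keys)
print('unexpected keys:', result.unexpected_keys)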

I’m thinking of cases where we might have more or fewer dictionary keys, or where the key names do not exactly match.
So far though the model seems to run as expected.
Thank you so much for your help and happy new year!!

EDIT:
@ptrblck I guess the issue I am seeing is losing the trained weights when I reload the model after training.

Say for example this is the fully trained model:

model_ft, train_losses, test_losses = train_model(model, criterion, optimizer,
                                                  exp_lr_scheduler, num_epochs=15)

Saving the model state_dict:

model_dict = {'arch': 'vgg19',
              'state_dict': model_ft.state_dict(),
              'class_to_idx': model_ft.class_to_idx,
              'weights': model_ft.features}

torch.save(model_dict, 'vgg19.pth')

Now say I reload the model:

model_vgg19 = load_model('vgg19.pth')

A quick comparison between model_vgg19 and model_ft shows that the predictions are way off, as if model_vgg19 did not learn anything from the previous one.
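
One way to confirm this (a quick sketch, assuming both models are still in memory) is to compare the two state dicts tensor by tensor:

sd_ft = model_ft.state_dict()
sd_loaded = model_vgg19.state_dict()
for k in sd_ft:
    if k not in sd_loaded:
        print('missing from reloaded model:', k)
    elif not torch.equal(sd_ft[k].cpu(), sd_loaded[k].cpu()):
        print('values differ:', k)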

In your load_model method you are removing the module. prefix that nn.DataParallel adds. Is this still needed in your current code snippet? Could you check the new_state_dict and see if the names are still valid?
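
Note that if the checkpoint keys no longer start with module., the slice k[7:] will chop seven characters off the real names (e.g. 'features.0.weight' becomes 's.0.weight'). A quick check (sketch):

print(list(state_dict.keys())[:3])          # keys stored in the checkpoint
print(list(new_state_dict.keys())[:3])      # keys after stripping 'module.'
print(list(model.state_dict().keys())[:3])  # keys the model expects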

I went through the function step by step in hopes of figuring out what’s going on, as below:

checkpoint = torch.load('vgg19.pth')
state_dict = checkpoint['state_dict']
temp_model = models.vgg19(pretrained=True)
temp_model.load_state_dict(state_dict, strict=True)

The error thrown is:

RuntimeError                              Traceback (most recent call last)
<ipython-input-35-d8320f134783> in <module>
----> 1 temp_model.load_state_dict(state_dict, strict=True)

~/anaconda3/envs/deep-learning/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    767         if len(error_msgs) > 0:
    768             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 769                                self.__class__.__name__, "\n\t".join(error_msgs)))
    770 
    771     def _named_members(self, get_members_fn, prefix='', recurse=True):

RuntimeError: Error(s) in loading state_dict for VGG:
	Missing key(s) in state_dict: "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias". 
	Unexpected key(s) in state_dict: "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias". 

The last nn.Sequential block in vgg19’s model.classifier is defined as:

(classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )

while yours is:

nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(25088, 4096)),
    ('relu', nn.ReLU()),
    ('fc2', nn.Linear(4096, 102)),
    ('output', nn.LogSoftmax(dim=1))
]))

which is throwing this error.
Based on your architecture it looks like only the parameters of the first linear layer can be loaded.
Could you try to just load these parameters separately?
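
For example, a sketch of that manual copy (the checkpoint’s fc1 has shape 4096 x 25088, which matches the stock classifier’s first Linear):

with torch.no_grad():
    temp_model.classifier[0].weight.copy_(state_dict['classifier.fc1.weight'])
    temp_model.classifier[0].bias.copy_(state_dict['classifier.fc1.bias'])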

Try

model.module.load_state_dict(checkpoint['state_dict'])
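
(This assumes the live model is still wrapped in nn.DataParallel, so that model.module is the underlying network and the checkpoint keys, which carry no module. prefix here, line up with it.)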