Can I change a custom ResNet-18 architecture subtly and still use it in pretrained mode?

Can I change a custom ResNet-18 architecture and still use it with pretrained=True? I am making a subtle change to the architecture of a custom resnet18, and when I run it, I get the error shown below.
This is how the custom resnet18 is called:
model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)

The new change in the custom resnet18:

self.layer_attend1 =  nn.Sequential(nn.Conv2d(layers[0], layers[0], stride=2, padding=1, kernel_size=3),
                                  nn.AdaptiveAvgPool2d(1),
                                  nn.Softmax(1))
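To see where the key names in the error come from, here is a minimal sketch of the new block in isolation. It assumes `layers[0]` is 64 channels (a hypothetical value; in the real model it is whatever the custom ResNet passes in). Only the `Conv2d` at index 0 of the `nn.Sequential` holds parameters, which is why the missing keys end in `layer_attend1.0.weight` and `layer_attend1.0.bias`.

```python
import torch
import torch.nn as nn

channels = 64  # hypothetical stand-in for layers[0]

layer_attend1 = nn.Sequential(
    nn.Conv2d(channels, channels, stride=2, padding=1, kernel_size=3),
    nn.AdaptiveAvgPool2d(1),  # (N, C, H', W') -> (N, C, 1, 1)
    nn.Softmax(1),            # normalise across the channel dimension
)

x = torch.randn(2, channels, 56, 56)
out = layer_attend1(x)
print(out.shape)  # torch.Size([2, 64, 1, 1])

# Only the conv at index 0 has learnable parameters, so the Sequential
# contributes exactly these two state_dict entries.
print(list(layer_attend1.state_dict().keys()))  # ['0.weight', '0.bias']
```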

The output of running the model is:

/scratch3/venv/fashcomp/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  warnings.warn("The use of the transforms.Scale transform is deprecated, " +
=> loading checkpoint 'runs/nondisjoint_l2norm/model_best.pth.tar'
Traceback (most recent call last):
  File "main.py", line 352, in <module>
    main()    
  File "main.py", line 145, in main
    tnet.load_state_dict(checkpoint['state_dict'])
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1406, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Tripletnet:
        Missing key(s) in state_dict: "embeddingnet.embeddingnet.layer_attend1.0.weight", "embeddingnet.embeddingnet.layer_attend1.0.bias".
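The error can be reproduced with two toy models (hypothetical `Old`/`New` classes, not the actual `Tripletnet`): a checkpoint saved before the extra layer existed has no entry for it, and the default `strict=True` in `load_state_dict` turns that gap into a `RuntimeError`.

```python
import torch
import torch.nn as nn

class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)

class New(Old):
    def __init__(self):
        super().__init__()
        # the "subtle" architectural change: one extra submodule
        self.layer_attend1 = nn.Sequential(nn.Conv2d(8, 8, kernel_size=3), nn.AdaptiveAvgPool2d(1))

checkpoint = {"state_dict": Old().state_dict()}  # stands in for the saved checkpoint

message = ""
try:
    New().load_state_dict(checkpoint["state_dict"])  # strict=True by default
except RuntimeError as err:
    message = str(err)
    print(message)  # lists layer_attend1.0.weight and layer_attend1.0.bias as missing
```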

So, how can you implement small architectural changes without retraining from scratch every time?

I don't know how to apply @ptrblck's answer to this case.

Or, for example, should I use this answer?

Could you please give a high-level (not code-level) hint as to why this answer should work?

I am also loading the checkpoint like this:
checkpoint = torch.load(args.resume, encoding='latin1')
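Before deciding how to patch the checkpoint, it can help to list exactly which names differ. A small sketch, assuming (as the traceback suggests) that the weights live under `checkpoint['state_dict']`; `compare_keys` is a hypothetical helper, not a PyTorch API:

```python
import torch.nn as nn

def compare_keys(model: nn.Module, state_dict: dict):
    """Hypothetical helper: report key differences between a model and a checkpoint."""
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    missing = sorted(model_keys - ckpt_keys)     # in the model, absent from the checkpoint
    unexpected = sorted(ckpt_keys - model_keys)  # in the checkpoint, unknown to the model
    return missing, unexpected

# Usage with the objects from the question (not run here):
# checkpoint = torch.load(args.resume, encoding='latin1')
# missing, unexpected = compare_keys(tnet, checkpoint['state_dict'])

# Self-contained demo with toy modules:
old = nn.Sequential(nn.Linear(4, 4))
new = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
missing, unexpected = compare_keys(new, old.state_dict())
print(missing)     # ['1.bias', '1.weight']
print(unexpected)  # []
```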

EDIT: If you come across this question, just see @Mona_Jalal's answer below (use model.load_state_dict(state_dict, strict=False)).

@Mona_Jalal it’s not quite the same. It might help if you include a snippet showing exactly how you try to load your checkpoint, but my guess is you’re doing something like:

state_dict = torch.load('checkpoint.pth')
model.load_state_dict(state_dict)

In that case, the fix is to put the missing keys into state_dict, like this:

state_dict = torch.load('checkpoint.pth')
state_dict['embeddingnet.embeddingnet.layer_attend1.0.weight'] = model.embeddingnet.embeddingnet.layer_attend1[0].weight
state_dict['embeddingnet.embeddingnet.layer_attend1.0.bias'] = model.embeddingnet.embeddingnet.layer_attend1[0].bias
model.load_state_dict(state_dict)

… or something like that, just making sure you’re pointing to the correct submodule. More generally you could do:

loaded_state_dict = torch.load('checkpoint.pth')
model_state_dict = model.state_dict()
for k, v in model_state_dict.items():
    if k not in loaded_state_dict:
        loaded_state_dict[k] = v
model.load_state_dict(loaded_state_dict)
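The loop above only covers keys that are missing outright. If the change also alters the shape of an existing layer (for example a different channel count), `load_state_dict` raises a size-mismatch error even with `strict=False`. A common variant, sketched here on toy modules and not tested against the original code, keeps only checkpoint tensors whose name and shape both match the current model and takes everything else from the model's own initialisation (`merge_compatible` is a hypothetical helper name):

```python
import torch
import torch.nn as nn

def merge_compatible(model: nn.Module, loaded: dict) -> dict:
    """Start from the model's own weights and overwrite every entry the checkpoint can supply."""
    merged = model.state_dict()
    for k, v in loaded.items():
        if k in merged and merged[k].shape == v.shape:
            merged[k] = v  # reuse the trained tensor
    return merged

old = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
new = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3), nn.Linear(3, 1))  # layer 1 resized, layer 2 added

new.load_state_dict(merge_compatible(new, old.state_dict()))  # strict=True now succeeds
print(torch.equal(new[0].weight, old[0].weight))  # True: unchanged layer reuses trained weights
```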

I haven’t tested any of this just now, so you may need to fiddle with it.

PS: Just a side note. You don’t need pretrained=True when creating the model, as you are just going to load up another set of weights anyway.


I set strict to False here and it is working. I didn’t change anything else.

tnet.load_state_dict(checkpoint['state_dict'], strict=False)
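As for the high-level question of why this works: `load_state_dict` matches tensors purely by name, so every layer whose name and shape are unchanged still receives its trained weights, and with `strict=False` names that have no counterpart are skipped instead of raising. The new `layer_attend1` therefore keeps its fresh random initialisation and has to learn during fine-tuning. A small sketch on toy modules (not the original `Tripletnet`) that also inspects the return value, which lists what was skipped:

```python
import torch
import torch.nn as nn

old = nn.Sequential(nn.Linear(4, 4))
new = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))  # extra layer, like layer_attend1
fresh = new[1].weight.clone()  # the new layer's random init, kept for comparison

result = new.load_state_dict(old.state_dict(), strict=False)
print(result.missing_keys)     # ['1.weight', '1.bias']
print(result.unexpected_keys)  # []

print(torch.equal(new[0].weight, old[0].weight))  # True: trained weights reused
print(torch.equal(new[1].weight, fresh))          # True: new layer left untouched
```

In the real script it is worth printing `missing_keys` once to confirm that only the two `layer_attend1` entries appear; any other name listed there would silently stay at its random initialisation.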