How to load pretrained weights for specific layers

Hello expert PyTorch folks

I have a question about loading pretrained weights into a network.

Let's say I am using VGG16.
I can use load_state_dict to reload the weights, which is pretty straightforward as long as my network stays the same!
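For reference, here is a minimal sketch of that straightforward case. It assumes a toy two-layer model (not the real VGG16) and saves to an in-memory buffer instead of a `.pth` file. Because both instances have identical parameter names and shapes, a strict load works:

```python
import io

import torch
import torch.nn as nn

# Two instances of the same toy architecture: parameter names and shapes match.
src = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(inplace=True))
dst = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(inplace=True))

# Serialize the weights to an in-memory buffer rather than a file on disk.
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)

# Every key and every shape lines up, so the default strict=True load succeeds.
dst.load_state_dict(torch.load(buffer))
```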

Now let's say I want to reload the pretrained VGG16 weights, but I change the architecture of the network: I add 2 more channels to the input, so instead of `nn.Conv2d(3, 64, 3, padding=1)` I use `nn.Conv2d(5, 64, 3, padding=1)`.
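For context, a quick check of which tensor actually changes: `nn.Conv2d` stores its weight as `(out_channels, in_channels, kernel_h, kernel_w)`, so going from 3 to 5 input channels only changes dimension 1 of the first layer's weight. The bias (and every later layer) keeps its shape:

```python
import torch.nn as nn

# Conv2d weight layout: (out_channels, in_channels, kernel_h, kernel_w).
conv_rgb = nn.Conv2d(3, 64, 3, padding=1)
conv_5ch = nn.Conv2d(5, 64, 3, padding=1)

print(conv_rgb.weight.shape)  # torch.Size([64, 3, 3, 3])
print(conv_5ch.weight.shape)  # torch.Size([64, 5, 3, 3])
# The bias depends only on out_channels, so it is the same for both layers.
print(conv_rgb.bias.shape == conv_5ch.bias.shape)  # True
```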

In the second case, when I call load_state_dict, I get the following error:

RuntimeError: Error(s) in loading state_dict for ModuleList:
size mismatch for 0.weight: copying a param of torch.Size([64, 5, 3, 3]) from checkpoint, where the shape is torch.Size([64, 3, 3, 3]) in current model.

Does anyone know how to fix this issue?
Is there a way to update just the weights for the first 3 input channels with the pretrained weights?
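One way to do exactly that, sketched here with two stand-alone layers (`pretrained_conv` is a hypothetical stand-in for the checkpoint's first conv, `new_conv` for the modified 5-channel one), is to copy into a slice of the new weight under `torch.no_grad()`. Input channels 3 and 4 keep their default random initialization:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: the pretrained 3-channel layer and the new 5-channel layer.
pretrained_conv = nn.Conv2d(3, 64, 3, padding=1)
new_conv = nn.Conv2d(5, 64, 3, padding=1)

with torch.no_grad():
    # Overwrite only input channels 0..2 of the new weight (dim 1 is in_channels).
    new_conv.weight[:, :3].copy_(pretrained_conv.weight)
    # The bias shape does not depend on in_channels, so it can be copied whole.
    new_conv.bias.copy_(pretrained_conv.bias)
```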

To be more specific:


```python
import torch
import torch.nn as nn


def VGG16(cfg, i, batch_norm=False):
    # cfg entries: 'M' = max pool, 'C' = max pool with ceil_mode, int = conv out_channels
    layers = []
    in_channels = i  # number of input channels of the first conv
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        elif v == 'C':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
    conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
    conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
    layers += [pool5, conv6,
               nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
    return layers


InputChannel = 5

bases = {
    'A': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 512, 512, 512]}
base = VGG16(bases['A'], InputChannel)

vgg = nn.ModuleList(base)
vgg_weights = torch.load('weights/vgg16_reducedfc.pth')
print('Loading base network...')
vgg.load_state_dict(vgg_weights, strict=False)
```
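Note that `strict=False` does not help here: it only tolerates missing or unexpected *keys*, while a key that exists on both sides with a different *shape* is still reported and `load_state_dict` raises. A small reproduction with a one-layer stand-in for the network:

```python
import torch.nn as nn

# One-layer stand-ins: the model expects 5 input channels, the checkpoint has 3.
model = nn.ModuleList([nn.Conv2d(5, 64, 3, padding=1)])
checkpoint = nn.ModuleList([nn.Conv2d(3, 64, 3, padding=1)]).state_dict()

try:
    # strict=False skips missing/unexpected keys, but not shape mismatches.
    model.load_state_dict(checkpoint, strict=False)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)
```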

For InputChannel = 3 I can load the pretrained weights; however, when I change the number of input channels (to 5, for example), I get the error:


RuntimeError: Error(s) in loading state_dict for ModuleList:
size mismatch for 0.weight: copying a param of torch.Size([64, 5, 3, 3]) from checkpoint, where the shape is torch.Size([64, 3, 3, 3]) in current model.

Please note that vgg16_reducedfc was trained with InputChannel = 3.
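The key `0.weight` in the error comes from how `nn.ModuleList` names its children: each one is keyed by its index in the list. Layers without parameters (ReLU, MaxPool) produce no entries but still use up an index, which is why the second conv shows up as `2.weight`. A small stand-in for the first few VGG layers:

```python
import torch.nn as nn

# A toy list shaped like the start of the VGG16 layer list (5 input channels).
layers = nn.ModuleList([
    nn.Conv2d(5, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
])

for name, tensor in layers.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (64, 5, 3, 3)
# 0.bias (64,)
# 2.weight (64, 64, 3, 3)
# 2.bias (64,)
```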

Update:
I also tried this:

```python
vgg = nn.ModuleList(base)
vgg_weights = torch.load('weights/vgg16_reducedfc.pth')
state = vgg.state_dict()
state.update(vgg_weights)
vgg.load_state_dict(state)
```

but I still get the same error 😑
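That attempt fails because `state.update(vgg_weights)` copies the checkpoint's 3-channel `0.weight` over the model's own 5-channel entry, so `load_state_dict` sees the same mismatch. A common workaround (sketched below with small stand-in modules rather than the real checkpoint) is to keep only the checkpoint entries whose name *and* shape match the model, and let the mismatched layer keep its fresh initialization:

```python
import torch.nn as nn

model = nn.ModuleList([nn.Conv2d(5, 64, 3, padding=1), nn.ReLU(inplace=True),
                       nn.Conv2d(64, 64, 3, padding=1)])
# Stand-in for torch.load('weights/vgg16_reducedfc.pth'): same layout, 3 input channels.
checkpoint = nn.ModuleList([nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(inplace=True),
                            nn.Conv2d(64, 64, 3, padding=1)]).state_dict()

state = model.state_dict()
# Keep only entries that exist in the model with exactly the same shape.
compatible = {k: v for k, v in checkpoint.items()
              if k in state and v.shape == state[k].shape}
skipped = sorted(set(checkpoint) - set(compatible))

state.update(compatible)
model.load_state_dict(state)
print("skipped:", skipped)  # skipped: ['0.weight']
```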

I did the following. It's kind of a hack, but it works.
Please let me know if there are any other solutions.

```python
import torch
import torch.nn as nn
from vgg import VGG16  # the VGG16() builder above, assumed saved as vgg.py


InputChannel = 5

bases = {
    'A': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 512, 512, 512]}
base = VGG16(bases['A'], InputChannel)

vgg = nn.ModuleList(base)
# Read the pretrained weights (trained with 3 input channels)
vgg_weights = torch.load('weights/vgg16_reducedfc.pth')
# Zero-pad the first conv weight along the input-channel dimension (dim 1)
# so it has InputChannel channels instead of 3
Main = vgg_weights['0.weight']
SIZEMain = Main.size()
Zeros = torch.zeros(SIZEMain[0], InputChannel - SIZEMain[1], SIZEMain[2], SIZEMain[3],
                    dtype=Main.dtype, device=Main.device)
NewMain = torch.cat((Main, Zeros), 1)
# Start from the model's own state, overwrite it with the checkpoint,
# then replace the mismatched entry with the padded weight
state = vgg.state_dict()
state.update(vgg_weights)
state['0.weight'] = NewMain
# Give the state to VGG
vgg.load_state_dict(state)
```
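The same zero-padding can be written without hard-coding the number of extra channels, for example with `torch.nn.functional.pad`. This is a sketch using small stand-in modules; `pad_input_channels` is a hypothetical helper name, not a PyTorch API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pad_input_channels(weight: torch.Tensor, in_channels: int) -> torch.Tensor:
    """Hypothetical helper: zero-pad a conv weight (out, in, kH, kW) on dim 1."""
    extra = in_channels - weight.shape[1]
    if extra < 0:
        raise ValueError("checkpoint has more input channels than the model")
    # F.pad takes (before, after) pairs starting from the LAST dim: W, H, then C_in.
    return F.pad(weight, (0, 0, 0, 0, 0, extra))


# Stand-ins for the pretrained checkpoint and the 5-channel model.
checkpoint = nn.ModuleList([nn.Conv2d(3, 64, 3, padding=1)]).state_dict()
model = nn.ModuleList([nn.Conv2d(5, 64, 3, padding=1)])

state = model.state_dict()
state.update(checkpoint)
state['0.weight'] = pad_input_channels(checkpoint['0.weight'], in_channels=5)
model.load_state_dict(state)
print(model[0].weight.shape)  # torch.Size([64, 5, 3, 3])
```

With zero-initialized extra channels the network initially responds to the first 3 channels exactly as the pretrained model did. The new channels still receive gradients during training, because a weight's gradient depends on its input rather than its current value.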