Hello expert PyTorch folks,
I have a question about loading pretrained weights for a network.
Let's say I am using a VGG16 net. I can use `load_state_dict` to reload the weights, which is pretty straightforward if my network stays the same!
Now let's say I want to reload the pretrained VGG16 weights, but I change the architecture of the network in the following way: I add 2 more channels to my input. So, for example, instead of doing

```python
nn.Conv2d(3, 64, 3, padding=1)
```

I will do

```python
nn.Conv2d(5, 64, 3, padding=1)
```
In the second case, when I use `load_state_dict`, I get the following error:

```
RuntimeError: Error(s) in loading state_dict for ModuleList:
    size mismatch for 0.weight: copying a param of torch.Size([64, 5, 3, 3]) from checkpoint, where the shape is torch.Size([64, 3, 3, 3]) in current model.
```
Does anyone know what I need to do to fix this issue? Is there any way to update just the first 3 input channels with the pretrained weights?
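Conceptually, what I want is something like this rough sketch (untested; `pretrained_weight` here is just a stand-in for the first-layer tensor from the checkpoint):

```python
import torch
import torch.nn as nn

# new first layer takes 5 input channels instead of 3
new_conv = nn.Conv2d(5, 64, 3, padding=1)

# stand-in for the pretrained [64, 3, 3, 3] kernel from the checkpoint
pretrained_weight = torch.randn(64, 3, 3, 3)

with torch.no_grad():
    # copy the pretrained filters into the first 3 input channels;
    # the 2 extra channels keep their random initialization
    new_conv.weight[:, :3, :, :].copy_(pretrained_weight)
```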
To be more specific:
```python
def VGG16(cfg, i, batch_norm=False):
    layers = []
    in_channels = i
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        elif v == 'C':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
    conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
    conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
    layers += [pool5, conv6,
               nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
    return layers
```
```python
InputChannel = 5
bases = {
    'A': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C',
          512, 512, 512, 'M', 512, 512, 512]}

base = VGG16(bases['A'], InputChannel)
vgg = nn.ModuleList(base)
vgg_weights = torch.load('weights/vgg16_reducedfc.pth')
print('Loading base network...')
vgg.load_state_dict(vgg_weights, strict=False)
```
For InputChannel = 3 I can load the pretrained weights; however, when I change the input channels (to 5, for example), I get the error:

```
RuntimeError: Error(s) in loading state_dict for ModuleList:
    size mismatch for 0.weight: copying a param of torch.Size([64, 5, 3, 3]) from checkpoint, where the shape is torch.Size([64, 3, 3, 3]) in current model.
```

Please note that `vgg16_reducedfc` was trained with InputChannel = 3.
Update:
I also tried this:

```python
vgg = nn.ModuleList(base)
vgg_weights = torch.load('weights/vgg16_reducedfc.pth')
state = vgg.state_dict()
state.update(vgg_weights)
vgg.load_state_dict(state)
```

but I am still getting the same error.
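Would the right approach be to filter out the shape-mismatched keys before calling `load_state_dict`? Something like this (a rough sketch, untested):

```python
vgg = nn.ModuleList(base)
vgg_weights = torch.load('weights/vgg16_reducedfc.pth')

model_state = vgg.state_dict()
# keep only the checkpoint tensors whose shapes match the current model,
# so the mismatched first-layer weight (3 vs. 5 input channels) is skipped
filtered = {k: v for k, v in vgg_weights.items()
            if k in model_state and v.shape == model_state[k].shape}
model_state.update(filtered)
vgg.load_state_dict(model_state)
```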