What is the strict=False argument of model.load_state_dict intended for?

Hi, I am working on reusing the layers of an existing pretrained NN model in a new NN model that adds a module to the existing one.

The pretrained NN model is as follows:

class DUNet(nn.Module):
    def __init__(self, in_ch, in_ch2, out_ch, out_ch2, bilinear):
        super(DUNet, self).__init__()
        self.encoder1 = UNetDown(in_ch=in_ch, bilinear=bilinear)
        self.encoder2 = UNetDown(in_ch=in_ch2, bilinear=bilinear)
        self.decoder1 = UNetUp(out_ch=out_ch, bilinear=bilinear)
        self.decoder2 = UNetUp(out_ch=out_ch2, bilinear=bilinear)
        self.outc = OutConv(4, 3)

    def forward(self, input_1, input_2):
        f1_1, f2_1, f3_1, f4_1, f5_1 = self.encoder1(input_1)
        f1_2, f2_2, f3_2, f4_2, f5_2 = self.encoder2(input_2)
        f6_1, f7_1, f8_1, f9_1, recon_1 = self.decoder1(f1_1, f2_1, f3_1, f4_1, f5_1)
        f6_2, f7_2, f8_2, f9_2, recon_2 = self.decoder2(f1_2, f2_2, f3_2, f4_2, f5_2)

        concat = torch.cat([recon_1, recon_2], dim=1)
        output = self.outc(concat)
        return output, f6_1, f7_1, f8_1, f9_1, f6_2, f7_2, f8_2, f9_2

It is a two-branch U-Net structure.
The new NN model is as follows:

class DUNet_finetune(nn.Module):
    def __init__(self, in_ch, in_ch2, out_ch, out_ch2, bilinear):
        super(DUNet_finetune, self).__init__()
        self.pretrained_DUNet = DUNet(in_ch, in_ch2, out_ch, out_ch2, bilinear)
        self.encoder1 = self.pretrained_DUNet.encoder1
        self.encoder2 = self.pretrained_DUNet.encoder2
        ... #(override the other variables as well)

    # define additional module
    # def new_module():
    # ...

    def forward(self, input_1, input_2):
        f1_1, f2_1, f3_1, f4_1, f5_1 = self.encoder1(input_1)
        f1_2, f2_2, f3_2, f4_2, f5_2 = self.encoder2(input_2)
        f6_1, f7_1, f8_1, f9_1, recon_1 = self.decoder1(f1_1, f2_1, f3_1, f4_1, f5_1)
        f6_2, f7_2, f8_2, f9_2, recon_2 = self.decoder2(f1_2, f2_2, f3_2, f4_2, f5_2)
        # additional module operated in this section
        #new_module()
        #...

        concat = torch.cat([recon_1, recon_2], dim=1)
        output = self.outc(concat)
        return output, f6_1, f7_1, f8_1, f9_1, f6_2, f7_2, f8_2, f9_2

When I trained DUNet, I wrapped the model with nn.DataParallel. (Actually, it wasn’t needed, but I didn’t change it.)
When I train the DUNet_finetune model, which uses the pretrained DUNet’s layers, the training code is as follows:

net = DUNet_finetune(in_ch=3, in_ch2=1, out_ch=3, out_ch2=1, bilinear=False).cuda()
net = torch.nn.DataParallel(net)

if opt.pretrained_guided:
    checkpoint = torch.load(PATH)
    net.module.pretrained_DUNet.load_state_dict(checkpoint['model_state_dict'])
    print('Use the Pretrained Network!')

And I get this error message:
Error(s) in loading state_dict for Parallel:
Missing key(s) in state_dict: “encoder1.inc.conv_blocks.0.weight”, …
Unexpected key(s) in state_dict: “module.encoder1.inc.conv_blocks.0.weight”, …

I made the error go away with

net.module.pretrained_DUNet.load_state_dict(checkpoint['model_state_dict'], strict=False)

by referring to the contents shown here.

But I don’t know what strict=False actually does, and I am concerned about whether the network will be trained the way I want. I want to fine-tune DUNet with an additional module.

I don’t know whether there is a better solution (e.g., training DUNet without wrapping it in nn.DataParallel, or something else).
Could you give me some advice?
Thank you very much.

You did not solve the issue; you are explicitly ignoring the mismatches, and most likely no parameters are loaded at all.
strict=False lets load_state_dict skip mismatching keys instead of raising an error. It fits the linked use case, where the user added one additional module.
In your case the state_dict contains keys with the module. prefix added by nn.DataParallel, while you are trying to load it into the raw model inside the nn.DataParallel wrapper, so none of the keys match.
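The mismatch can be reproduced on a small model. This is only a sketch: make_net is a hypothetical stand-in for DUNet (two conv layers), since only the key names matter here. It shows that load_state_dict with strict=False reports the mismatches in its return value and leaves every parameter at its initial value.

```python
import torch
import torch.nn as nn


def make_net():
    # Hypothetical stand-in for DUNet; only the parameter names matter.
    return nn.Sequential(nn.Conv2d(3, 4, 3), nn.Conv2d(4, 3, 3))


# A state_dict taken from the DataParallel wrapper: every key starts with "module."
checkpoint_sd = nn.DataParallel(make_net()).state_dict()

# The raw (unwrapped) model expects keys without that prefix.
raw = make_net()
before = {k: v.clone() for k, v in raw.state_dict().items()}

# strict=False returns the mismatches instead of raising an error.
result = raw.load_state_dict(checkpoint_sd, strict=False)
print("missing keys:   ", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)

# No key matched, so nothing was copied into the model.
unchanged = all(torch.equal(before[k], v) for k, v in raw.state_dict().items())
print("parameters unchanged:", unchanged)
```

Checking missing_keys and unexpected_keys after a strict=False load is a cheap way to confirm that only the keys you expected to skip were skipped.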
Make sure to store and load the same state_dict, ideally from the internal model (not from the nn.DataParallel model).
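A minimal sketch of both ways to get matching keys, again on a hypothetical stand-in model (make_net) rather than the real DUNet. Option 1 saves the state_dict of the internal model via .module. Option 2 handles a checkpoint that was already saved from the wrapper by removing the module. prefix with a small helper (strip_prefix is an illustrative name, not a PyTorch function). The in-memory buffer stands in for a checkpoint file.

```python
import io
from collections import OrderedDict

import torch
import torch.nn as nn


def make_net():
    # Hypothetical stand-in for DUNet.
    return nn.Sequential(nn.Conv2d(3, 4, 3), nn.Conv2d(4, 3, 3))


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


wrapped = nn.DataParallel(make_net())

# Option 1: save the internal model's state_dict, so keys have no "module." prefix.
buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save({"model_state_dict": wrapped.module.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")

fresh = make_net()
res1 = fresh.load_state_dict(checkpoint["model_state_dict"])  # strict=True by default

# Option 2: the checkpoint was already saved from the wrapper; strip the prefix.
old_sd = wrapped.state_dict()  # keys look like "module.0.weight"
fresh2 = make_net()
res2 = fresh2.load_state_dict(strip_prefix(old_sd))

print(res1)
print(res2)
```

In the setup from the question, the stripped state_dict would be the one passed to net.module.pretrained_DUNet.load_state_dict with the default strict=True, so that any remaining mismatch raises an error instead of being skipped silently.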

I understand you said “train DUNet without wrapping with nn.DataParallel.” I’ll give it a try. Thank you for your advice.