Hi, I am working on reusing the layers of an existing pretrained NN model in a new NN model that adds an extra module on top of the existing architecture.
The pretrained NN model is as follows:
class DUNet(nn.Module):
    def __init__(self, in_ch, in_ch2, out_ch, out_ch2, bilinear):
        super(DUNet, self).__init__()
        self.encoder1 = UNetDown(in_ch=in_ch, bilinear=bilinear)
        self.encoder2 = UNetDown(in_ch=in_ch2, bilinear=bilinear)
        self.decoder1 = UNetUp(out_ch=out_ch, bilinear=bilinear)
        self.decoder2 = UNetUp(out_ch=out_ch2, bilinear=bilinear)
        self.outc = OutConv(4, 3)

    def forward(self, input_1, input_2):
        f1_1, f2_1, f3_1, f4_1, f5_1 = self.encoder1(input_1)
        f1_2, f2_2, f3_2, f4_2, f5_2 = self.encoder2(input_2)
        f6_1, f7_1, f8_1, f9_1, recon_1 = self.decoder1(f1_1, f2_1, f3_1, f4_1, f5_1)
        f6_2, f7_2, f8_2, f9_2, recon_2 = self.decoder2(f1_2, f2_2, f3_2, f4_2, f5_2)
        concat = torch.cat([recon_1, recon_2], dim=1)
        output = self.outc(concat)
        return output, f6_1, f7_1, f8_1, f9_1, f6_2, f7_2, f8_2, f9_2
It is a two-branch U-Net structure.
And the new NN model is as follows:
class DUNet_finetune(nn.Module):
    def __init__(self, in_ch, in_ch2, out_ch, out_ch2, bilinear):
        super(DUNet_finetune, self).__init__()
        self.pretrained_DUNet = DUNet(in_ch, in_ch2, out_ch, out_ch2, bilinear)
        self.encoder1 = self.pretrained_DUNet.encoder1
        self.encoder2 = self.pretrained_DUNet.encoder2
        ...  # (override the other variables as well)
        # define additional module
        # def new_module():
        # ...

    def forward(self, input_1, input_2):
        f1_1, f2_1, f3_1, f4_1, f5_1 = self.encoder1(input_1)
        f1_2, f2_2, f3_2, f4_2, f5_2 = self.encoder2(input_2)
        f6_1, f7_1, f8_1, f9_1, recon_1 = self.decoder1(f1_1, f2_1, f3_1, f4_1, f5_1)
        f6_2, f7_2, f8_2, f9_2, recon_2 = self.decoder2(f1_2, f2_2, f3_2, f4_2, f5_2)
        # additional module operated in this section
        # new_module()
        # ...
        concat = torch.cat([recon_1, recon_2], dim=1)
        output = self.outc(concat)
        return output, f6_1, f7_1, f8_1, f9_1, f6_2, f7_2, f8_2, f9_2
When I trained DUNet, I wrapped the model with nn.DataParallel. (It actually wasn't needed, but I didn't change it.)
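As far as I understand (this is my assumption, not something from my actual training code), nn.DataParallel stores the wrapped model under a .module attribute, so every key in its state_dict gets a "module." prefix. A minimal sketch of what I mean, with a toy module:

import torch
import torch.nn as nn

# toy module just to illustrate the key names (hypothetical, not my real model)
plain = nn.Linear(4, 2)
print(list(plain.state_dict().keys()))    # ['weight', 'bias']

wrapped = nn.DataParallel(plain)
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']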
When I train the DUNet_finetune model, which uses the pretrained DUNet's layers, the training code is as follows:
net = DUNet_finetune(in_ch=3, in_ch2=1, out_ch=3, out_ch2=1, bilinear=False).cuda()
net = torch.nn.DataParallel(net)
if opt.pretrained_guided:
    checkpoint = torch.load(PATH)
    net.module.pretrained_DUNet.load_state_dict(checkpoint['model_state_dict'])
    print('Use the Pretrained Network!')
And I get this error message:
Error(s) in loading state_dict for Parallel:
    Missing key(s) in state_dict: "encoder1.inc.conv_blocks.0.weight", ...
    Unexpected key(s) in state_dict: "module.encoder1.inc.conv_blocks.0.weight", ...
I’ve solved this problem with
net.module.pretrained_DUNet.load_state_dict(checkpoint['model_state_dict'], strict=False)
by referring to the contents shown here.
But I don't fully understand what strict=False does here, and I am concerned about whether the network will be trained the way I intend. I want to fine-tune DUNet together with the additional module.
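I guess one way to check would be to inspect what strict=False actually skipped, since load_state_dict returns the missing and unexpected keys (just a sketch reusing net and checkpoint from above; I have not confirmed this is the right way to verify):

# inspect what strict=False silently ignored
result = net.module.pretrained_DUNet.load_state_dict(checkpoint['model_state_dict'], strict=False)
print('missing keys   :', result.missing_keys)      # model params left at their initial values
print('unexpected keys:', result.unexpected_keys)   # checkpoint entries that were not used

If every key shows up as missing/unexpected because of the "module." prefix, I worry that nothing from the checkpoint is actually loaded.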
I also don't know whether there is a better solution (e.g., training DUNet without wrapping it in nn.DataParallel in the first place, or stripping the "module." prefix from the checkpoint keys, as in the sketch below).
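What I had in mind is something like this (just a sketch, assuming the only mismatch is the "module." prefix that nn.DataParallel adds to the keys):

checkpoint = torch.load(PATH)
state_dict = checkpoint['model_state_dict']
# drop the leading "module." that nn.DataParallel added when the checkpoint was saved
stripped = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
net.module.pretrained_DUNet.load_state_dict(stripped)  # strict=True by default

But I am not sure if this is the recommended way.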
Could you give me some advice?
Thank you very much.