How to transfer weights from a pretrained model to a different model based on it

I have two models: a base model, and a "new model" whose architecture is slightly different from the base model. I am trying to load the weights from the base model into the new model using load_state_dict, but I keep getting an "unexpected key(s)" error for the modules that differ, even though I have already initialized their weights.

This is my old (base) model:

class SfsNetPipeline(nn.Module):
    """SfSNet pipeline (base model).

    Runs a shared convolutional trunk, then separate normal and albedo
    branches plus an SH light estimator, and finally re-renders the input
    face from the predicted components.

    The sub-module attribute names below become the ``state_dict`` keys, so
    they must stay in sync with any checkpoint saved from this class.
    """

    def __init__(self):
        super().__init__()

        # Shared feature extractor.
        self.conv_model            = baseFeaturesExtractions()
        # Normal branch: residual features -> normal map.
        self.normal_residual_model = NormalResidualBlock()
        self.normal_gen_model      = NormalGenerationNet()
        # Albedo branch: residual features -> albedo map.
        self.albedo_residual_model = AlbedoResidualBlock()
        self.albedo_gen_model      = AlbedoGenerationNet()
        # SH lighting estimator.
        self.light_estimator_model = LightEstimator()

    def get_face(self, sh, normal, albedo):
        """Re-render a face from SH lighting, normals and albedo."""
        return reconstruct_image(get_shading(normal, sh), albedo)

    def forward(self, face):
        """Decompose ``face`` and reconstruct it.

        Returns:
            Tuple ``(normal, albedo, sh, shading, recon)``.
        """
        trunk = self.conv_model(face)

        normal_feats = self.normal_residual_model(trunk)
        albedo_feats = self.albedo_residual_model(trunk)

        normal = self.normal_gen_model(normal_feats)
        albedo = self.albedo_gen_model(albedo_feats)

        # The light estimator sees trunk, normal and albedo features
        # stacked along dim=1.
        stacked = torch.cat((trunk, normal_feats, albedo_feats), dim=1)
        sh = self.light_estimator_model(stacked)

        shading = get_shading(normal, sh)
        recon = reconstruct_image(shading, albedo)
        return normal, albedo, sh, shading, recon

And this is the new model (I marked the parts that differ with comments):

class SfSRGBNetPipeline(nn.Module):
    """SfSNet pipeline variant with an extra lighting branch.

    Compared with ``SfsNetPipeline`` this model:
      * adds a light-feature branch (``light_residual_model`` ->
        ``light_gen_model``) whose output is returned as
        ``predicted_correction``;
      * feeds the light features (instead of the raw conv features) into
        ``LightEstimatorNew`` together with the normal and albedo features.

    Sub-modules shared with ``SfsNetPipeline`` keep the same attribute names,
    so their ``state_dict`` keys line up with a base-model checkpoint. The new
    ``light_*`` modules have no pretrained counterpart. ``LightEstimatorNew``
    may also differ from ``LightEstimator`` in layer names or shapes (TODO
    confirm), in which case those layers have no counterpart either.
    """

    def __init__(self):
        super().__init__()

        # Modules shared with the base SfsNetPipeline (same attribute names).
        self.conv_model            = baseFeaturesExtractions()
        self.normal_residual_model = NormalResidualBlock()
        self.normal_gen_model      = NormalGenerationNet()
        self.albedo_residual_model = AlbedoResidualBlock()
        self.albedo_gen_model      = AlbedoGenerationNet()
        self.light_estimator_model = LightEstimatorNew()
        # New modules added in this variant (not present in the base model).
        self.light_residual_model = LightFeatures()
        self.light_gen_model      = LightGenerationNet()

    def get_face(self, sh, normal, albedo):
        """Re-render a face from SH lighting, normals and albedo."""
        shading = get_shading(normal, sh)
        recon   = reconstruct_image(shading, albedo)
        return recon

    def forward(self, face):
        """Decompose ``face`` and reconstruct it.

        Returns:
            Tuple ``(predicted_normal, predicted_albedo, predicted_correction,
            predicted_sh, out_shading, out_recon)``. Note that this has one
            more element than ``SfsNetPipeline.forward``, inserted at index 2.
        """
        # 1. Extract shared features from the input image.
        out_features = self.conv_model(face)
        # 2 a. Normal residual features.
        out_normal_features = self.normal_residual_model(out_features)
        # 2 b. Albedo residual features.
        out_albedo_features = self.albedo_residual_model(out_features)
        # 3 a. Generate normal.
        predicted_normal = self.normal_gen_model(out_normal_features)
        # 3 b. Generate albedo.
        predicted_albedo = self.albedo_gen_model(out_albedo_features)

        # 3 c. New branch: light features and the correction output.
        out_light_features = self.light_residual_model(out_features)
        predicted_correction = self.light_gen_model(out_light_features)

        # 3 d. Estimate lighting. Unlike the base model, the LIGHT features
        # (not the raw conv features) are stacked with the normal and albedo
        # features along dim=1.
        all_features = torch.cat((out_light_features, out_normal_features, out_albedo_features), dim=1)
        predicted_sh = self.light_estimator_model(all_features)

        # 4. Shading from the predicted normal and SH lighting.
        out_shading = get_shading(predicted_normal, predicted_sh)
        # 5. Reconstruct the image from shading and albedo.
        out_recon = reconstruct_image(out_shading, predicted_albedo)

        return predicted_normal, predicted_albedo, predicted_correction, predicted_sh, out_shading, out_recon

And this is how I load the weights:

# Transfer weights from the base SfSNet checkpoint into the modified model.
#
# Every parameter/buffer whose name AND shape match the checkpoint is copied.
# Everything else (the new light_* modules and any changed layers) keeps its
# ``weights_init`` initialisation. A plain strict
# ``load_state_dict(old_dict)`` raises as soon as a key exists in only one of
# the two models. ``strict=False`` still raises on shape mismatches.
sfs_net_ori      = SfsNetPipeline()
sfs_netku        = SfSRGBNetPipeline()

# Initialise first so the modules without a pretrained counterpart start
# from ``weights_init``; matching tensors are overwritten below. Doing this
# outside ``if use_cuda`` means CPU runs get the same initialisation.
sfs_netku.apply(weights_init)

# Load onto the CPU so the transfer works with or without CUDA; the model
# is moved to the GPU afterwards.
checkpoint = torch.load(model_dir + 'sfs_net_model_9.pkl', map_location='cpu')
sfs_net_pretrained_old = checkpoint['model_state_dict']
sfs_net_pretrained_new = sfs_netku.state_dict()

# A checkpoint saved from an nn.DataParallel wrapper prefixes every key with
# 'module.'. Strip it, because we load into the still-unwrapped model.
_dp_prefix = 'module.'
transferred = {}
skipped = []
for key, value in sfs_net_pretrained_old.items():
    name = key[len(_dp_prefix):] if key.startswith(_dp_prefix) else key
    target = sfs_net_pretrained_new.get(name)
    if target is not None and target.shape == value.shape:
        transferred[name] = value
    else:
        skipped.append(key)

sfs_net_pretrained_new.update(transferred)
# The merged dict has exactly the model's keys and shapes, so a strict load
# succeeds and would still catch genuine mistakes.
sfs_netku.load_state_dict(sfs_net_pretrained_new)

kept_init = sorted(set(sfs_net_pretrained_new) - set(transferred))
print(f'Transferred {len(transferred)} tensors from the checkpoint.')
print(f'Skipped checkpoint entries (no matching name/shape): {skipped}')
print(f'Kept weights_init for: {kept_init}')

if use_cuda:
    # NOTE(review): an earlier version wrapped ``sfs_net_model``, which is
    # not defined in this snippet; ``sfs_net_ori`` is the base model created
    # above. Confirm against the rest of the script.
    sfs_net_ori = nn.DataParallel(sfs_net_ori).cuda()
    sfs_netku = nn.DataParallel(sfs_netku).cuda()
1 Like

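Do you still get the error if you pass strict=False to load_state_dict?
With strict=False, missing or unexpected keys no longer raise an error. Instead, the call returns them (as missing_keys and unexpected_keys), so you can check what was skipped. Note that it does not ignore parameters whose names match but whose shapes differ; those still raise an error.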

Nope, that solved it. Thank you! If I may ask another question: should I still load the optimizer state from the old pretrained model?

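I'm not sure how training will behave if the optimizer has internal running estimates for some parameters while others are empty, so your best bet is to try it out, and please report your findings. :wink: Also note that optimizer.load_state_dict matches the saved state to parameters by position, not by name. It will raise an error if the new model's parameter groups contain a different number of parameters than the old ones.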

1 Like