I have two models: a base model and a second model with a slightly different architecture, which I call the new model. I am trying to load the weights from the base model into the new model using `load_state_dict`, but I keep getting an "unexpected key" error for the new modules, even though I have already initialized their weights.
This is my old (base) model:
class SfsNetPipeline(nn.Module):
    """SfSNet pipeline: predict normal, albedo and lighting, then re-render.

    ``__init__`` builds the submodules; ``forward`` chains them and rebuilds
    the input face from the predicted components.
    """

    def __init__(self):
        super().__init__()
        # Shared feature extractor applied to the input face.
        self.conv_model = baseFeaturesExtractions()
        # Normal branch: residual block followed by a generator.
        self.normal_residual_model = NormalResidualBlock()
        self.normal_gen_model = NormalGenerationNet()
        # Albedo branch: residual block followed by a generator.
        self.albedo_residual_model = AlbedoResidualBlock()
        self.albedo_gen_model = AlbedoGenerationNet()
        # Lighting head; its output is passed to get_shading as ``sh``
        # (presumably spherical-harmonics coefficients -- confirm).
        self.light_estimator_model = LightEstimator()

    def get_face(self, sh, normal, albedo):
        """Return the image rebuilt from lighting ``sh``, ``normal`` and ``albedo``."""
        return reconstruct_image(get_shading(normal, sh), albedo)

    def forward(self, face):
        """Run the full pipeline on ``face``.

        Returns:
            Tuple ``(normal, albedo, sh, shading, recon)``.
        """
        features = self.conv_model(face)

        normal_features = self.normal_residual_model(features)
        albedo_features = self.albedo_residual_model(features)

        normal = self.normal_gen_model(normal_features)
        albedo = self.albedo_gen_model(albedo_features)

        # The lighting head sees the shared features plus both branch
        # features, concatenated along dim 1.
        stacked = torch.cat((features, normal_features, albedo_features), dim=1)
        sh = self.light_estimator_model(stacked)

        shading = get_shading(normal, sh)
        recon = reconstruct_image(shading, albedo)
        return normal, albedo, sh, shading, recon
And this is the new model (I have commented the parts where it differs):
class SfSRGBNetPipeline(nn.Module):
    """SfSNet pipeline variant with an extra lighting-feature branch.

    Differences from ``SfsNetPipeline`` that matter when loading a base-model
    checkpoint:

    * ``light_residual_model`` and ``light_gen_model`` are new submodules.
      A base checkpoint has no entries for them, so they keep whatever
      initialisation was applied before loading.
    * ``light_estimator_model`` is a ``LightEstimatorNew`` registered under
      the same attribute name as the base model's ``LightEstimator``.
      Its state-dict keys therefore share the ``light_estimator_model.``
      prefix. If its parameter shapes differ, those checkpoint entries cannot
      be loaded as-is. NOTE(review): confirm whether the shapes match.
    * ``forward`` returns six values (``predicted_correction`` is inserted
      third), so callers written for the base model's 5-tuple must be
      updated.
    """

    def __init__(self):
        super().__init__()
        self.conv_model = baseFeaturesExtractions()
        self.normal_residual_model = NormalResidualBlock()
        self.normal_gen_model = NormalGenerationNet()
        self.albedo_residual_model = AlbedoResidualBlock()
        self.albedo_gen_model = AlbedoGenerationNet()
        # Replaces the base model's LightEstimator under the same name.
        self.light_estimator_model = LightEstimatorNew()
        # New lighting branch (absent from the base model).
        self.light_residual_model = LightFeatures()
        self.light_gen_model = LightGenerationNet()

    def get_face(self, sh, normal, albedo):
        """Return the image rebuilt from lighting ``sh``, ``normal`` and ``albedo``."""
        shading = get_shading(normal, sh)
        recon = reconstruct_image(shading, albedo)
        return recon

    def forward(self, face):
        """Run the pipeline on ``face``.

        Returns:
            Tuple ``(predicted_normal, predicted_albedo, predicted_correction,
            predicted_sh, out_shading, out_recon)``.
        """
        # 1. Shared conv features.
        out_features = self.conv_model(face)
        # 2. Normal and albedo residual features.
        out_normal_features = self.normal_residual_model(out_features)
        out_albedo_features = self.albedo_residual_model(out_features)
        # 3a/3b. Generate normal and albedo.
        predicted_normal = self.normal_gen_model(out_normal_features)
        predicted_albedo = self.albedo_gen_model(out_albedo_features)
        # 3c. New lighting branch, also fed from the shared conv features.
        out_light_features = self.light_residual_model(out_features)
        predicted_correction = self.light_gen_model(out_light_features)
        # 3d. Estimate lighting. Unlike the base model, the light-branch
        # features take the place of the raw conv features in the concat
        # (along dim 1).
        all_features = torch.cat(
            (out_light_features, out_normal_features, out_albedo_features), dim=1
        )
        predicted_sh = self.light_estimator_model(all_features)
        # 4. Shading from normal + lighting.
        out_shading = get_shading(predicted_normal, predicted_sh)
        # 5. Reconstruct the image from shading + albedo.
        out_recon = reconstruct_image(out_shading, predicted_albedo)
        return (
            predicted_normal,
            predicted_albedo,
            predicted_correction,
            predicted_sh,
            out_shading,
            out_recon,
        )
And this is how I load the model:
def _load_matching_weights(model, pretrained_state):
    """Copy into ``model`` every checkpoint tensor whose name and shape match.

    ``load_state_dict`` is strict by default. It raises when the checkpoint
    has keys the model lacks ("unexpected"), when the model has keys the
    checkpoint lacks ("missing"), or when tensor shapes disagree. A model
    with added or replaced submodules therefore cannot strictly load an
    older checkpoint.

    This helper starts from the model's own state dict, which already holds
    any initialisation applied beforehand. It overwrites only the entries
    the checkpoint can supply, then loads the merged dict, so every key is
    present.

    A leading ``module.`` (added by ``nn.DataParallel``) is stripped from
    checkpoint keys. The result is loaded into the unwrapped module, so this
    works whether or not either side was wrapped.

    Args:
        model: target ``nn.Module``, optionally wrapped in ``nn.DataParallel``
            or ``DistributedDataParallel``.
        pretrained_state: mapping of parameter/buffer name to tensor.

    Returns:
        ``(merged_state, skipped_keys)``: the state dict that was loaded, and
        the checkpoint keys that had no same-named, same-shaped target.
    """
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, wrappers) else model
    merged_state = target.state_dict()
    skipped_keys = []
    prefix = "module."
    for key, tensor in pretrained_state.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in merged_state and merged_state[name].shape == tensor.shape:
            merged_state[name] = tensor
        else:
            skipped_keys.append(key)
    target.load_state_dict(merged_state)
    return merged_state, skipped_keys


sfs_net_ori = SfsNetPipeline()
sfs_netku = SfSRGBNetPipeline()
# Initialise every parameter first. Layers the checkpoint cannot supply (the
# new light branch, and any shape-mismatched light-estimator tensors) keep
# these values after the partial load below.
sfs_netku.apply(weights_init)

# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host; the
# weights move to the GPU with the model below. torch.load unpickles the
# file, so only load checkpoints from a trusted source.
checkpoint = torch.load(model_dir + 'sfs_net_model_9.pkl', map_location='cpu')
sfs_net_pretrained_old = checkpoint['model_state_dict']
sfs_net_pretrained_new, skipped_keys = _load_matching_weights(
    sfs_netku, sfs_net_pretrained_old
)
if skipped_keys:
    print('Checkpoint keys not loaded into the new model:', skipped_keys)

if use_cuda:
    sfs_net_ori = nn.DataParallel(sfs_net_ori).cuda()
    sfs_netku = nn.DataParallel(sfs_netku).cuda()