Loading weights of specific modules in the model

I have two models (based on faster_rcnn from torchvision). I want to replace the roi_heads weights in model 2 with those from model 1.

The state dict from model 1 has these roi_heads keys:

 'roi_heads.box_head.fc6.weight'
 'roi_heads.box_head.fc6.bias'
 'roi_heads.box_head.fc7.weight'
 'roi_heads.box_head.fc7.bias'
 'roi_heads.box_predictor.cls_score.weight'
 'roi_heads.box_predictor.cls_score.bias'
 'roi_heads.box_predictor.bbox_pred.weight' 
 'roi_heads.box_predictor.bbox_pred.bias'

Here is how I load model 1:

ckpt = torch.load('saved_models/model1.pth')
model = get_model(15)
model.load_state_dict(ckpt['model'])

Model 2 is identical except that its roi_heads weights are different. How can I load the roi_heads weights from model 1 into model 2's roi_heads?

@ptrblck Can you please provide a snippet?

How would you like to replace the roi_head parameters, if both roi_heads are different?
Are they just using different names or also different shapes of the parameters?

Both roi_heads have the same key names; only the weights are different. The shapes of the layers in between are the same too. I just want to load the roi_heads weights from the other model.

I just want to replace these weights:
'roi_heads.box_head.fc6.weight', 'roi_heads.box_head.fc6.bias' with those of another, identical roi_heads, which will be used for making predictions.

In that case something like this should work:

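# Copies modelB's roi_heads weights into modelA (the attribute is named roi_heads, matching the keys above)
modelA.roi_heads.load_state_dict(modelB.roi_heads.state_dict())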
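
To illustrate the idea without downloading a detection model, here is a minimal sketch using a small stand-in module. `TinyDetector` and its layer sizes are made up for illustration and are not part of torchvision; the only thing it shares with the real model is a submodule attribute called `roi_heads`. It shows that loading a submodule's state dict copies only that submodule's parameters and leaves the rest of the model alone.

```python
import torch
from torch import nn


class TinyDetector(nn.Module):
    """Hypothetical stand-in for a detector with a backbone and roi_heads."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.roi_heads = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))


torch.manual_seed(0)
model_a = TinyDetector()  # receives the weights
model_b = TinyDetector()  # provides the weights

# Keep a copy of the backbone weights to check that they are not modified.
backbone_before = model_a.backbone.weight.detach().clone()

# load_state_dict copies the values into model_a's existing parameters.
result = model_a.roi_heads.load_state_dict(model_b.roi_heads.state_dict())

heads_match = all(
    torch.equal(p_a, p_b)
    for p_a, p_b in zip(model_a.roi_heads.parameters(), model_b.roi_heads.parameters())
)
backbone_untouched = torch.equal(model_a.backbone.weight, backbone_before)
```

Since the default is `strict=True`, the call raises an error if the key names or shapes of the two submodules differ, which is a useful sanity check when the heads are supposed to be identical.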

Alternatively, strip the 'roi_heads.' prefix from the checkpoint keys (keeping only the roi_heads entries) and try:

checkpoint = ckpt['model']
# Remove the 'roi_heads.' prefix from the keys
checkpoint_cleaned = ...
model.roi_heads.load_state_dict(checkpoint_cleaned)
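
One possible way to do the key cleanup is sketched below. The helper name `strip_prefix` is hypothetical (not a PyTorch function), and the toy checkpoint uses strings in place of tensors so the key handling is easy to follow. It keeps only the entries under the given prefix and removes the prefix from their keys.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix):
    """Keep entries whose key starts with `prefix` and drop the prefix from the key."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


# Toy checkpoint: the strings stand in for weight tensors.
checkpoint = {
    'backbone.body.conv1.weight': 'w0',
    'roi_heads.box_head.fc6.weight': 'w1',
    'roi_heads.box_head.fc6.bias': 'b1',
}

checkpoint_cleaned = strip_prefix(checkpoint, 'roi_heads.')
```

With a real checkpoint, the result could then be passed to the submodule, e.g. `model.roi_heads.load_state_dict(strip_prefix(ckpt['model'], 'roi_heads.'))`, assuming the checkpoint stores the full model state dict under the 'model' key as in the question.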