Detectron2 custom MaskRCNN loading

I have created a custom MaskRCNN head in detectron2 by adding an attention module to it, but when I load the model after training, I get these warnings:

The checkpoint state_dict contains keys that are not used by the model:
  roi_heads.mask_head.bam1.channel_att.mlp.0.{bias, weight}
  roi_heads.mask_head.bam1.channel_att.mlp.2.{bias, weight}
  roi_heads.mask_head.bam1.spatial_att.layers.0.{bias, weight}
  roi_heads.mask_head.bam1.spatial_att.layers.1.{bias, num_batches_tracked, running_mean, running_var, weight}

I am using the default config YAML file, but the weights come from the model trained with the custom MaskRCNN head. In the inference script, I also first define the custom mask head class that I used for training, and only then load the weights.
When I print the model, it shows all the layers and modules I have added. Inference runs as well, but these warnings make me think that my inference model is not using the trained weights of the attention module. Is there a correct (or alternative) way to load the custom model?

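The right approach is to rebuild the model with the custom layers before loading the state_dict.
If the warning is still raised, check the attribute names and make sure they match the keys in the state_dict. For example, the attention module has to be registered as bam1 inside the mask head so that its parameters end up under the roi_heads.mask_head.bam1. prefix.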
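Checking the names boils down to comparing two sets of keys. Below is a minimal, pure-Python sketch of that comparison; diff_keys is a hypothetical helper, and the key lists are illustrative values modeled on the warning above, not taken from a real checkpoint.

```python
from typing import Iterable, List, Tuple


def diff_keys(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (missing, unexpected) key lists, sorted."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    missing = sorted(model_set - ckpt_set)     # expected by the model, absent from the checkpoint
    unexpected = sorted(ckpt_set - model_set)  # present in the checkpoint, unused by the model
    return missing, unexpected


# Illustrative keys only: the checkpoint contains the attention block "bam1",
# but the model was built with a mask head that has no "bam1" attribute.
ckpt_keys = [
    "roi_heads.mask_head.mask_fcn1.weight",
    "roi_heads.mask_head.bam1.channel_att.mlp.0.weight",
    "roi_heads.mask_head.bam1.channel_att.mlp.0.bias",
]
model_keys = ["roi_heads.mask_head.mask_fcn1.weight"]

missing, unexpected = diff_keys(model_keys, ckpt_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model you would pass model.state_dict().keys() and the checkpoint's keys (for a file written by detectron2's DetectionCheckpointer, the weights are typically stored under the "model" entry). Unexpected keys with no matching missing keys suggest the built model has no module at that path at all, while missing and unexpected keys that differ only in a prefix point to a renamed attribute. PyTorch's load_state_dict(sd, strict=False) also returns these two lists directly as missing_keys and unexpected_keys.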

@ptrblck, thank you for your response. If I understand correctly, you mean I should clone the git repo, add the attention module I am using to the file where the mask head class is defined, and then rebuild it with the following commands?

git clone https://github.com/facebookresearch/detectron2.git
cd detectron2
python setup.py build develop

During training, I assigned my custom mask head class name in the config like this: cfg.MODEL.ROI_MASK_HEAD.NAME='MaskRCNNConvUpsampleHead_'. I followed the same process for inference and set cfg.MODEL.ROI_MASK_HEAD.NAME='MaskRCNNConvUpsampleHead_' before loading the custom model's state_dict. Why is this approach not working, or is it actually working? Thank you.
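One plausible reason such a config setting can appear to have no effect is build order: in detectron2 the mask head class is looked up by name in a registry when the model is constructed, so the name has to be set (and the module defining the class imported) before the model is built. The snippet is a toy, pure-Python analogy of that mechanism only; Registry, MASK_HEADS, DefaultMaskHead, AttentionMaskHead, and build_model here are hypothetical stand-ins, not detectron2's API.

```python
# Toy illustration only: a tiny name -> class registry standing in for
# detectron2's mask head registry. None of these names are detectron2 APIs.


class Registry:
    def __init__(self):
        self._classes = {}

    def register(self):
        # Used as @REGISTRY.register(), mirroring detectron2's decorator style.
        def deco(cls):
            self._classes[cls.__name__] = cls
            return cls
        return deco

    def get(self, name):
        return self._classes[name]  # KeyError if the class was never registered


MASK_HEADS = Registry()


@MASK_HEADS.register()
class DefaultMaskHead:
    submodules = ("mask_fcn1", "predictor")


@MASK_HEADS.register()
class AttentionMaskHead(DefaultMaskHead):  # hypothetical custom head
    submodules = ("mask_fcn1", "bam1", "predictor")


def build_model(cfg):
    # The head name is read exactly once, when the model is built.
    return MASK_HEADS.get(cfg["ROI_MASK_HEAD_NAME"])()


cfg = {"ROI_MASK_HEAD_NAME": "DefaultMaskHead"}
model = build_model(cfg)
cfg["ROI_MASK_HEAD_NAME"] = "AttentionMaskHead"  # too late for the existing model
print(type(model).__name__, model.submodules)

model = build_model(cfg)  # rebuilding picks up the custom head
print(type(model).__name__, model.submodules)
```

In detectron2 itself, the corresponding order would be: import the module that registers your class with ROI_MASK_HEAD_REGISTRY.register(), set cfg.MODEL.ROI_MASK_HEAD.NAME, and only then build the model, e.g. with build_model(cfg) or DefaultPredictor(cfg). DefaultPredictor clones the config and builds and loads its model in its constructor, so changing cfg afterwards does not affect an existing predictor.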

No, you shouldn’t need to modify the repository directly. Just redo the same model manipulations before loading the state_dict, as seen in this small example:

import torch.nn as nn
from torchvision import models

model = models.resnet18()

# manipulate the model
model.fc = nn.Linear(512, 10)

# save state_dict
sd = model.state_dict()


# restore the model
model = models.resnet18()

# loading the state_dict into the original model will fail
model.load_state_dict(sd)
# RuntimeError: Error(s) in loading state_dict for ResNet:
# 	size mismatch for fc.weight: copying a param with shape torch.Size([10, 512]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
# 	size mismatch for fc.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([1000]).

# manipulate the model in the same way
model.fc = nn.Linear(512, 10)

# loading the state_dict now works
model.load_state_dict(sd)
# <All keys matched successfully>
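The example above shows a shape mismatch; the detectron2 warning corresponds to the other case, where the checkpoint contains an extra submodule. A minimal sketch of that case, assuming a hypothetical MaskHead class with a placeholder bam1 layer (not detectron2's head or the actual attention module), loads with strict=False so PyTorch reports the unused keys instead of raising:

```python
import torch.nn as nn


class MaskHead(nn.Module):
    """Hypothetical stand-in for a mask head; not detectron2's class."""

    def __init__(self, with_attention=False):
        super().__init__()
        self.mask_fcn1 = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        if with_attention:
            # Placeholder for the attention block; only its attribute name matters here.
            self.bam1 = nn.Conv2d(8, 1, kernel_size=1)


# state_dict saved from the model that had the attention block
sd = MaskHead(with_attention=True).state_dict()

# loading into the plain head: the bam1.* keys are reported as unexpected
result = MaskHead().load_state_dict(sd, strict=False)
print(sorted(result.unexpected_keys))

# rebuilding with the same modification first: every key is used
result = MaskHead(with_attention=True).load_state_dict(sd, strict=False)
print(result.missing_keys, result.unexpected_keys)
```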

I see, got it. Thank you @ptrblck.