Hey all,
I would like to reuse the weights I obtained by training classifiers on ResNet-50 and MobileNetV3. The segmentation dataset is very similar to the classification one, so I want to transfer those weights into the FPN backbone of Mask R-CNN instead of loading the ImageNet-pretrained backbone or model. I did it as shown below, but I am not sure it is correct.
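If the classification model was trained inside `nn.DataParallel` or `DistributedDataParallel`, every key in the saved state dict carries a `module.` prefix and will match neither a plain ResNet-50 nor the FPN body. A minimal sketch of a cleanup step in plain Python, assuming the checkpoint dict is stored under `'model'` as in the code below and the classifier head is named `fc`; `clean_classifier_state_dict` is a hypothetical helper, not a torchvision function.

```python
def clean_classifier_state_dict(state_dict, head_prefix="fc."):
    """Return a copy with 'module.' prefixes stripped and classifier-head entries removed."""
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith("module."):
            # prefix added by nn.DataParallel / DistributedDataParallel when saving
            key = key[len("module."):]
        if key.startswith(head_prefix):
            # the 3-class fc layer has no counterpart in the FPN body
            continue
        cleaned[key] = value
    return cleaned


# Toy example with placeholder values; in a real checkpoint the values are tensors
ckpt = {"module.conv1.weight": 1, "module.layer4.2.bn3.bias": 2, "module.fc.weight": 3}
print(clean_classifier_state_dict(ckpt))
# {'conv1.weight': 1, 'layer4.2.bn3.bias': 2}
```

The cleaned dict can then be passed to `load_state_dict` on the ResNet-50 or on the FPN body (`backbone.body`).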
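import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone


def maskrcnn_50fpn_transfer(num_classes=2, path=None):
    """
    :param num_classes: number of classes including background
    :param path: path to the weights saved after classification training
    :return: Mask R-CNN whose ResNet-50 FPN body starts from the classification weights
    """
    resnet50 = torchvision.models.resnet50(pretrained=False)
    # ResNet-50 was originally trained on three classes
    resnet50.fc = nn.Linear(in_features=2048, out_features=3)
    print('Loading covid weights from classification')
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    # strict=True (the default) so a key mismatch raises instead of being silently skipped
    resnet50.load_state_dict(checkpoint['model'])
    # `pretrained` expects a bool; passing a model object is truthy and loads ImageNet weights
    backbone = resnet_fpn_backbone('resnet50', pretrained=False)
    # backbone.body shares the conv1, bn1 and layer1-layer4 names with resnet50;
    # strict=False only skips the fc keys, which the FPN body does not have
    backbone.body.load_state_dict(resnet50.state_dict(), strict=False)
    model = MaskRCNN(backbone, num_classes=num_classes)
    return model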
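Because `strict=False` silently ignores any name that does not line up, it can help to check which keys actually transfer before trusting the result. A minimal sketch in plain Python, assuming you feed it something like `checkpoint['model'].keys()` and `backbone.body.state_dict().keys()`; `compare_keys` is a hypothetical helper, and the key lists below are shortened toy examples of ResNet-50 naming.

```python
def compare_keys(source_keys, target_keys):
    """Split parameter names into shared, source-only and target-only lists."""
    source, target = set(source_keys), set(target_keys)
    return {
        "shared": sorted(source & target),       # these weights get copied
        "source_only": sorted(source - target),  # dropped by strict=False (e.g. fc.*)
        "target_only": sorted(target - source),  # keep their initial values
    }


# Shortened toy key lists; the real ones contain a few hundred entries
report = compare_keys(
    ["conv1.weight", "layer1.0.conv1.weight", "fc.weight", "fc.bias"],
    ["conv1.weight", "layer1.0.conv1.weight"],
)
print(report["source_only"])
# ['fc.bias', 'fc.weight']
```

If `target_only` is non-empty, or `shared` is empty because of a prefix mismatch, the backbone is not actually starting from the classification weights.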
Any hints would be greatly appreciated.
Az