How to load weights trained on a custom dataset into an FPN backbone for segmentation

Hey all,

I would like to reuse the weights I obtained by training ResNet50 and MobileNetV3 classifiers. The segmentation dataset is very similar to the classification one. I would like to transfer these weights to the FPN backbone of Mask R-CNN without loading torchvision's pretrained backbone or model. I did it as below, but I am not sure it is correct.

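import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.backbone_utils import BackboneWithFPN

def maskrcnn_50fpn_transfer(num_classes=2, path=None):
    """
    :param num_classes: number of classes, including background
    :param path: path to the checkpoint saved after classification training
    :return: Mask R-CNN whose ResNet50-FPN body starts from the classification weights
    """
    # Build ResNet50 without ImageNet weights (nothing is downloaded)
    resnet50 = torchvision.models.resnet50()

    # The ResNet was originally trained on three classes, so resize fc to match the checkpoint
    resnet50.fc = nn.Linear(in_features=2048, out_features=3)

    print('Loading COVID weights from classification')
    checkpoint = torch.load(path, map_location=torch.device('cpu'))

    # strict=True (the default) raises on missing or unexpected keys instead of silently skipping them
    resnet50.load_state_dict(checkpoint['model'])

    # resnet_fpn_backbone takes a boolean `pretrained` flag, not a model, so passing resnet50 there
    # would not use these weights. Wrap the loaded ResNet in an FPN directly instead.
    return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
    in_channels_list = [256, 512, 1024, 2048]
    backbone = BackboneWithFPN(resnet50, return_layers, in_channels_list, out_channels=256)

    model = MaskRCNN(backbone, num_classes)

    return model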

Any hint would be greatly appreciated.