Replace a class method of a predefined torch model

I was wondering whether it is possible to change a method of a preloaded model. In my case:

```python
import torchvision

model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=False,
                                                            num_classes=len(class_labels))
```
  • I have a predefined RetinaNet model with a ResNet50 FPN backbone.
  • I managed to update the RetinaNetHead module, adding a new head to the model (with a custom loss and forward).
  • As far as I can see, RetinaNet(nn.Module) has a method called postprocess_detections, which I need to modify so that it also returns the detections from the new head.

Is it possible to subclass RetinaNet and override that method while still using the preloaded model, or do I have to rewrite the whole class and then add the FPN backbone myself?

If you understand what postprocess_detections does, I don't see any issue with overriding it in a subclass, since overriding methods in a subclass is exactly how the class is meant to be extended.