Adjusting Keypoint R-CNN or fine-tuning it end-to-end with a GCN


I’m using KEYPOINTRCNN_RESNET50_FPN from the torchvision library to estimate pose from RGB images. The model works well overall, but it has no explicit notion of how keypoints relate to one another, so it produces many "flying" points (keypoints that land far from the rest of the skeleton). One solution is to use a graph-based network, such as a graph convolutional network (GCN), to model the relationships between keypoints. The problem is that Keypoint R-CNN in training mode is a self-contained model: it is trained separately from any other component, has its own loss functions, and returns only a dictionary of losses rather than predictions. How can I add a graph network and train it end-to-end together with Keypoint R-CNN, or turn one of the model’s layers into a graph layer?
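For concreteness, the kind of graph layer I have in mind is a standard GCN propagation step, ReLU(Â X W), where Â = D^-1/2 (A + I) D^-1/2 is the normalized adjacency of the skeleton graph. Below is a minimal NumPy sketch, assuming the standard COCO 17-keypoint order and skeleton (the person-keypoint format the torchvision model is trained on). The helper names are my own, and the per-keypoint features and weights are random placeholders rather than anything produced by Keypoint R-CNN.

```python
import numpy as np

# COCO person skeleton as 1-indexed keypoint pairs (assumed to match the
# 17-keypoint order used by the torchvision Keypoint R-CNN person model).
COCO_SKELETON = [
    (16, 14), (14, 12), (17, 15), (15, 13), (12, 13), (6, 12), (7, 13),
    (6, 7), (6, 8), (7, 9), (8, 10), (9, 11), (2, 3), (1, 2), (1, 3),
    (2, 4), (3, 5), (4, 6), (5, 7),
]
NUM_KEYPOINTS = 17


def normalized_adjacency(edges, num_nodes):
    """Return D^-1/2 (A + I) D^-1/2 for an undirected skeleton graph."""
    a = np.eye(num_nodes)  # self-loops, so every degree is at least 1
    for i, j in edges:
        a[i - 1, j - 1] = 1.0
        a[j - 1, i - 1] = 1.0
    d_inv_sqrt = 1.0 / np.sqrt(a.sum(axis=1))
    # Element-wise: a_ij * d_i^-1/2 * d_j^-1/2
    return a * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def gcn_layer(x, adj, weight):
    """One graph convolution, ReLU(adj @ x @ weight).

    x: (num_nodes, in_dim) per-keypoint features, e.g. (x, y, score).
    """
    return np.maximum(adj @ x @ weight, 0.0)


adj = normalized_adjacency(COCO_SKELETON, NUM_KEYPOINTS)
rng = np.random.default_rng(0)
features = rng.normal(size=(NUM_KEYPOINTS, 3))  # placeholder per-keypoint input
weight = rng.normal(size=(3, 8))                # untrained weights, illustration only
out = gcn_layer(features, adj, weight)
print(out.shape)  # (17, 8)
```

Each output row combines a keypoint’s features with those of its direct skeleton neighbours, which is the kind of coupling I would hope helps suppress isolated flying points.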

Do you have any advice on which layer I should adjust and how, or on whether it’s possible to fine-tune this model jointly with another one?
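One option I’m considering is to keep the whole detector and wrap only its last keypoint layer, model.roi_heads.keypoint_predictor, which outputs one heatmap logit map per keypoint for each RoI. Here is a rough PyTorch sketch; the class name and the hidden width are hypothetical. The wrapper runs the original predictor, mixes the heatmaps of neighbouring keypoints through a (K, K) normalized skeleton adjacency, and adds the result back as a residual. The output projection is zero-initialized so that, before any fine-tuning, the wrapped model behaves exactly like the pretrained one.

```python
import torch
from torch import nn


class GraphRefinedKeypointPredictor(nn.Module):
    """Hypothetical wrapper: runs the original keypoint predictor, then adds a
    residual refinement in which each keypoint heatmap aggregates the heatmaps
    of its skeleton neighbours (one graph-convolution step)."""

    def __init__(self, base_predictor: nn.Module, adjacency: torch.Tensor, hidden: int = 8):
        super().__init__()
        self.base = base_predictor
        self.hidden = hidden
        # (K, K) normalized adjacency, e.g. D^-1/2 (A + I) D^-1/2 of the skeleton
        self.register_buffer("adj", adjacency.float())
        self.embed = nn.Conv2d(1, hidden, kernel_size=3, padding=1)
        self.project = nn.Conv2d(hidden, 1, kernel_size=3, padding=1)
        # Zero-init so the wrapped model starts out identical to the base model
        nn.init.zeros_(self.project.weight)
        nn.init.zeros_(self.project.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.base(x)  # (N, K, H, W) heatmap logits
        n, k, h, w = logits.shape
        # Embed every keypoint heatmap independently into `hidden` channels
        feats = self.embed(logits.reshape(n * k, 1, h, w))
        feats = feats.reshape(n, k, self.hidden, h, w)  # node features
        # Graph propagation along the keypoint axis: node i sums adj[i, j] * node j
        feats = torch.relu(torch.einsum("ij,njchw->nichw", self.adj, feats))
        delta = self.project(feats.reshape(n * k, self.hidden, h, w))
        return logits + delta.reshape(n, k, h, w)
```

It would be plugged in with something like model.roi_heads.keypoint_predictor = GraphRefinedKeypointPredictor(model.roi_heads.keypoint_predictor, adj), where adj is a (17, 17) tensor built from the skeleton. Because roi_heads passes the predictor’s output directly to its built-in keypoint loss in training mode (and to the heatmap decoding at inference), the existing loss dictionary should backpropagate through the graph step as well, so no separate loss or training loop would be needed. The hidden width is mostly a memory trade-off, since the intermediate tensor has shape (N, K, hidden, H, W).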

Thanks a lot.