Speed up Keypoint R-CNN inference

I have a custom keypoint detection framework, which I train on a few thousand COCO-annotated images of everyday objects. The model predicts both keypoints and bounding boxes well, and training takes only a few minutes, but inference is quite slow. The saved model is about 250 MB (about 130 MB if I convert it to fp16 precision with model.half()).

The training and loading code is fairly straightforward:

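import torch
import torchvision
from torchvision.models.detection.anchor_utils import AnchorGenerator

anchor_generator = AnchorGenerator(
    sizes=(32, 64, 128, 256, 512),
    aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0),
)

# Create the keypoint R-CNN model
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
    num_classes=4,  # 4 classes in total; class 0 is the background
    rpn_anchor_generator=anchor_generator,
)

# ... training ...

# Save the trained weights
torch.save(model.state_dict(), weights_path)

# Later, for inference: build the model as above, then load the saved weights
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device).eval()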

The input I want to run inference on has shape [512 x 3 x 128 x 128], i.e. 512 RGB images of 128x128 pixels. I process it in batches, as below, to avoid running out of memory on my RTX 4090.

rcnn_batch_size = 16
predictions = []

# Loop over the list and perform predictions in batches
with torch.no_grad():
    for i in range(0, len(side_cam_rgb_tensors), rcnn_batch_size):
        batch_rgb_tensors = side_cam_rgb_tensors[i:i + rcnn_batch_size]
        batch_rgb_tensors = torch.stack(batch_rgb_tensors).to(device)
        batch_predictions = model(batch_rgb_tensors)  # This is the slow statement; it takes about 30 sec.
        predictions.extend(batch_predictions)  # one dict per image

for p in predictions:
    # Take p['scores'] and p['keypoints'] and transform them appropriately
    ...
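
Because CUDA kernels are launched asynchronously, wall-clock timing of a single statement on the GPU can attribute time to the wrong line unless the device is synchronized, and the first calls also include one-time setup such as memory allocation. A small helper along these lines (time_forward is a hypothetical name) gives a steadier per-batch number for comparing settings such as fp16 or different batch sizes:

```python
import time

import torch


def time_forward(model, batch, warmup=2, iters=5):
    """Average seconds per forward pass, excluding warm-up iterations."""
    model.eval()
    on_cuda = batch.is_cuda
    with torch.no_grad():
        for _ in range(warmup):  # first calls include one-time setup costs
            model(batch)
        if on_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before starting the clock
        start = time.perf_counter()
        for _ in range(iters):
            model(batch)
        if on_cuda:
            torch.cuda.synchronize()  # kernel launches are asynchronous; wait for them to finish
        return (time.perf_counter() - start) / iters
```

For example, time_forward(model, torch.stack(side_cam_rgb_tensors[:16]).to(device)) would time one batch of 16 images with the question's variables.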

Are there steps I can take to speed up inference, possibly trading off some accuracy or training time? Converting the model and input to fp16 improves speed by only about 10%. I have also tried torch.compile(model), but it runs into an issue similar to this one: Compilation hangs: too many threads. Any suggestions are appreciated. Thanks!