How could I change the RPN Head structure?

Hello everyone,

I’m trying to use the weights of a pre-trained Mask R-CNN with an FPN and a ResNet101 backbone. However, it uses a different RPN head structure than the one provided by torchvision. The torchvision head (RPNHead in torchvision/models/detection/rpn.py) has the following structure, where 256 is the number of input channels coming from the FPN:

self.conv = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)        # shared 3x3 conv
self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1, stride=1)     # 1 objectness logit per anchor
self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1, stride=1)  # 4 box deltas per anchor

the one I’m trying to use has the following structure:

self.conv = nn.Conv2d(256, 512, kernel_size=3, stride=anchor_stride, padding=1)      # wider shared conv (512 channels)
self.cls_logits = nn.Conv2d(512, 2 * anchors_per_location, kernel_size=1, stride=1)  # 2 logits per anchor
self.bbox_pred = nn.Conv2d(512, 4 * anchors_per_location, kernel_size=1, stride=1)   # 4 box deltas per anchor

After I manually replace the RPN head, I get an IndexError raised from the filter_proposals method (full traceback below; a sketch of how I swap the head in follows it). How can I work around this? As far as I can tell, the method is written for a specific head structure. Or am I missing something?

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-14-6ae18adb1800> in <module>
      1 mask_rcnn.to(device)
      2 mask_rcnn.eval()
----> 3 _ = mask_rcnn([img_torch])

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     68         if isinstance(features, torch.Tensor):
     69             features = OrderedDict([('0', features)])
---> 70         proposals, proposal_losses = self.rpn(images, features, targets)
     71         detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
     72         detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/torchvision/models/detection/rpn.py in forward(self, images, features, targets)
    470         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
    471         proposals = proposals.view(num_images, -1, 4)
--> 472         boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
    473 
    474         losses = {}

/opt/conda/lib/python3.7/site-packages/torchvision/models/detection/rpn.py in filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level)
    384         objectness = objectness[batch_idx, top_n_idx]
    385         levels = levels[batch_idx, top_n_idx]
--> 386         proposals = proposals[batch_idx, top_n_idx]
    387 
    388         final_boxes = []

IndexError: index 361016 is out of bounds for dimension 1 with size 242991
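
For completeness, here is roughly how I’m swapping the head in (simplified sketch; CustomRPNHead is a stand-in for the pre-trained implementation’s head, and I’m using the ResNet50 constructor only to illustrate the replacement):

import torch
import torchvision
from torch import nn

class CustomRPNHead(nn.Module):
    def __init__(self, in_channels, anchors_per_location, anchor_stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 512, kernel_size=3, stride=anchor_stride, padding=1)
        self.cls_logits = nn.Conv2d(512, 2 * anchors_per_location, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(512, 4 * anchors_per_location, kernel_size=1, stride=1)

    def forward(self, x):
        # torchvision's RPN passes a list of FPN feature maps and expects
        # one (objectness, bbox_deltas) pair per level in return
        logits, bbox_reg = [], []
        for feature in x:
            t = torch.relu(self.conv(feature))
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg

mask_rcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
mask_rcnn.rpn.head = CustomRPNHead(256, anchors_per_location=3)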

What value does your implementation use for anchors_per_location, and what is torchvision’s num_anchors set to?
I’m not sure whether the architecture is really different or whether only the naming was changed.
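
You can read torchvision’s value directly off the model’s anchor generator, e.g. (attribute path as in the torchvision detection models):

# one entry per FPN level: len(sizes) * len(aspect_ratios) anchors per location;
# torchvision builds RPNHead with the first entry as num_anchors
print(mask_rcnn.rpn.anchor_generator.num_anchors_per_location())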

It uses anchors_per_location = 3, but I think you’re right: it is a very similar implementation. However, while PyTorch’s implementation returns a logit tensor of shape [batch, anchors, height, width], the one I’m using returns a logit tensor of shape [batch, 2 * anchors per location, height, width]. Any idea why the number of anchor logits is doubled? I’m still going through the code and will post an update as soon as I reach a (possibly partial) conclusion; a first guess is sketched below.
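
Purely as a guess while I keep digging: if the doubled channels turn out to be (background, foreground) pairs per anchor, something like this should recover the single-logit-per-anchor shape torchvision expects (the channel layout here is an assumption I still have to verify):

import torch

def foreground_logits(cls_logits, anchors_per_location=3):
    # cls_logits: [batch, 2 * anchors_per_location, H, W]
    # assumed (unverified) layout: channels grouped as (bg, fg) per anchor
    n, _, h, w = cls_logits.shape
    paired = cls_logits.view(n, anchors_per_location, 2, h, w)
    return paired[:, :, 1]  # keep foreground logits -> [batch, anchors, H, W]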

Thanks