MaskRCNN with MobileNet backbone

I am trying to build a MaskRCNN model with MobileNetv2 backbone using mobilenet_backbone() function.

Here is my code:

from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.backbone_utils import mobilenet_backbone
backbone = mobilenet_backbone(
    backbone_name='mobilenet_v2',
    pretrained=True,
    fpn=True)

model = MaskRCNN(backbone, num_classes)

Printed model architecture:

## skip the backbone
    (fpn): FeaturePyramidNetwork(                                                                          
      (inner_blocks): ModuleList(                                                                          
        (0): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))                                            
        (1): Conv2d(576, 256, kernel_size=(1, 1), stride=(1, 1))                                           
      )                                                                                                    
      (layer_blocks): ModuleList(                                                                          
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))                           
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))                           
      )                                                                                                    
      (extra_blocks): LastLevelMaxPool()                                                                   
    )                                                                                                      
  )
  (rpn): RegionProposalNetwork(                                                                            
    (anchor_generator): AnchorGenerator()                                                                  
    (head): RPNHead(                                                                                       
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))                          
      (cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))                                      
      (bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))                                      
    )                                                                                                      
  )                                                                                                        
  (roi_heads): RoIHeads(                                                                                   
    (box_roi_pool): MultiScaleRoIAlign()                                                                   
    (box_head): TwoMLPHead(                                                                                
      (fc6): Linear(in_features=12544, out_features=1024, bias=True)                                       
      (fc7): Linear(in_features=1024, out_features=1024, bias=True)                                        
    )
    (box_predictor): FastRCNNPredictor(
      (cls_score): Linear(in_features=1024, out_features=2, bias=True)
      (bbox_pred): Linear(in_features=1024, out_features=8, bias=True)
    )
    (mask_roi_pool): MultiScaleRoIAlign()
    (mask_head): MaskRCNNHeads(
      (mask_fcn1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu1): ReLU(inplace=True)
      (mask_fcn2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu2): ReLU(inplace=True)
      (mask_fcn3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu3): ReLU(inplace=True)
      (mask_fcn4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu4): ReLU(inplace=True)
    )
    (mask_predictor): MaskRCNNPredictor(
      (conv5_mask): ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
      (relu): ReLU(inplace=True)
      (mask_fcn_logits): Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

But now when I try to do a forward call, I get the following error:

model.eval()
x = torch.randn(3, 128, 128)
model([x])
~/miniconda3/envs/torch-detect/lib/python3.8/site-packages/torchvision/models/detection/anchor_utils.py in grid_anchors(self, grid_sizes, strides)
    101 
    102         if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
--> 103             raise ValueError("Anchors should be Tuple[Tuple[int]] because each feature "
    104                              "map could potentially have different sizes and aspect ratios. "
    105                              "There needs to be a match between the number of "

ValueError: Anchors should be Tuple[Tuple[int]] because each feature map could potentially have different sizes and aspect ratios. There needs to be a match between the number of feature maps passed and the number of sizes / aspect ratios specified.

I realized that when a backbone is built with BackboneWithFPN(), I would need to make an anchor_generator and specify the sizes and aspect_ratios with flat tuples (single-level, not nested tuples). However, when not using BackboneWithFPN(), the anchor_generator needs a nested tuple, i.e. Tuple[Tuple[int]]. I have made the following script to experiment with the MobileNet backbone with or without an FPN:

"""
MobileNet Backbone for MaskRCNN
script: backbones.py
  - python -m backbones --use_fpn=0
  - python -m backbones --use_fpn=1
"""

import sys
import argparse

import torch
from torch import nn
import torchvision
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import LastLevelMaxPool
from torchvision.models import mobilenet
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.models.detection.mask_rcnn import MaskRCNN

from ops import model_ops

def get_mobilenet_backbone(
        use_fpn: bool = True,
        trainable_layers: int = 2):
    """Build a MobileNetV3-small feature extractor for Mask R-CNN.

    Args:
        use_fpn: If True, wrap the backbone in a FeaturePyramidNetwork
            (BackboneWithFPN) that returns multiple feature maps; otherwise
            return a plain ``nn.Sequential`` producing a single feature map.
        trainable_layers: Number of trailing stages left trainable; all
            earlier stages are frozen. Defaults to 2 (matching the previous
            hard-coded value).

    Returns:
        An ``nn.Module`` exposing an ``out_channels`` attribute where
        required, suitable as the ``backbone`` argument of torchvision's
        ``MaskRCNN``.
    """
    # Pretrained MobileNetV3-small with frozen batch-norm statistics, as is
    # conventional for detection backbones.
    bb = mobilenet.__dict__['mobilenet_v3_small'](
        pretrained=True,
        norm_layer=misc_nn_ops.FrozenBatchNorm2d)
    backbone = bb.features

    # Stage boundaries: blocks flagged ``_is_cn`` begin a new stage
    # (torchvision marks stride-changing blocks), plus first and last block.
    stage_indices = [0] + [
        i for i, b in enumerate(backbone)
        if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    print(stage_indices)
    num_stages = len(stage_indices)
    assert 0 <= trainable_layers <= num_stages

    # Freeze everything before the last ``trainable_layers`` stages.
    if trainable_layers == 0:
        freeze_before = num_stages
    else:
        freeze_before = stage_indices[num_stages - trainable_layers]

    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    out_channels = 1280
    extra_blocks = LastLevelMaxPool()

    if use_fpn:
        # Feed the last two stages into the FPN; LastLevelMaxPool adds a
        # third, coarser feature map on top of them.
        returned_layers = [num_stages - 2, num_stages - 1]
        assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
        return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)}

        in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
        backbone = BackboneWithFPN(
            backbone, return_layers,
            in_channels_list, out_channels,
            extra_blocks=extra_blocks)
    else:
        backbone = nn.Sequential(
           backbone,
           # 1x1 conv: linear combination of channels to reach out_channels
           nn.Conv2d(backbone[-1].out_channels, out_channels, 1))
        # MaskRCNN reads this attribute to size the RPN and ROI heads.
        backbone.out_channels = out_channels

    return backbone


def main(args):
    """Build the backbone, smoke-test it alone, then run a full MaskRCNN forward."""
    backbone = get_mobilenet_backbone(args.use_fpn)

    # Smoke-test the backbone on a single random 3x64x64 image.
    x = torch.randn(3, 64, 64)
    features = backbone(x.unsqueeze(0))
    if args.use_fpn:
        print('Output of backbone with FPN:', features.keys(), [features[k].shape for k in features])
    else:
        print('Output of backbone without FPN:', features.shape)

    # With an FPN the generator gets flat tuples and expands them to one
    # entry per feature map; without an FPN there is a single feature map,
    # so the per-map nesting must be written out explicitly.
    if args.use_fpn:
        anchor_generator = AnchorGenerator(
            sizes=(32, 64, 128),
            aspect_ratios=(0.5, 1.0, 2.0))
    else:
        anchor_generator = AnchorGenerator(
            sizes=((32, 64, 128),),
            aspect_ratios=((0.5, 1.0, 2.0),))

    # ROI pooling for box and mask heads; both read feature map '0' only.
    box_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['0'], output_size=7, sampling_ratio=2)
    mask_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['0'], output_size=14, sampling_ratio=2)

    model = MaskRCNN(
        backbone, num_classes=2,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=box_pooler,
        mask_roi_pool=mask_pooler)

    print('model:', model)
    model.eval()
    detections = model([x])
    print('Output of model:', [f"'masks':{d['masks'].shape}" for d in detections])


def parse(argv):
    """Parse command-line flags.

    Args:
        argv: Argument list, e.g. ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``use_fpn`` coerced to bool
        (False when the flag is omitted).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--use_fpn', type=int, required=False, default=0,
        choices=[0, 1])

    # Bug fix: parse the argv that was passed in instead of silently
    # falling back to sys.argv.
    args = parser.parse_args(argv)
    args.use_fpn = bool(args.use_fpn)

    return args


if __name__ == '__main__':
    # Script entry point: parse CLI flags (skipping the program name) and run.
    args = parse(sys.argv[1:])
    main(args)

I am looking into replacing the Mask R-CNN backbone with MobileNet v2. Is the complete code for this available in a GitHub repo?