Torchvision Object Detection Tutorial: adding a quantization step

Hi, I have been trying to implement a quantized Mask R-CNN for a project I am working on, but I am not having much success. I have been following the Torchvision Object Detection Finetuning Tutorial here.

I have changed some of the code, but most of it is still the same. I have implemented a class that wraps the model in a quantise/dequantise block, and I added a step to the get-model function that quantises the model using a post-training static method. I have also tried quantising just the backbone of the model with a separate class instead.

However, both approaches give me the same error, shown below. Can someone please help me figure out where I am going wrong?

My code can be found here.

Error:

  File "/home/harry/Downloads/pendan/PennFudanPed/quantised_mask_rcnn_model.py", line 196, in main
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
  File "/home/harry/Downloads/pendan/PennFudanPed/engine.py", line 30, in train_one_epoch
    loss_dict = model(images, targets)
  File "/home/harry/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
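The traceback shows `_call_impl` passing every argument straight through with `self.forward(*input, **kwargs)`, while `engine.py` calls `model(images, targets)`. Below is a minimal, torch-free sketch of that mismatch. It assumes a hypothetical `ToyModule` class (not PyTorch code) that dispatches the same way, and a wrapper with the same one-argument `forward(self, x)` as `MQuantise`:

```python
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: calling an instance forwards
    every argument to forward(), like `self.forward(*input, **kwargs)` in the
    traceback."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class OneArgWrapper(ToyModule):
    """Same forward signature as MQuantise: forward(self, x)."""

    def forward(self, x):
        return x


def call_like_engine(model, images, targets):
    """Mimics `loss_dict = model(images, targets)` in engine.py and returns
    the TypeError message, if one is raised."""
    try:
        model(images, targets)
    except TypeError as err:
        return str(err)
    return "no error"


message = call_like_engine(OneArgWrapper(), images=["img0"], targets=[{"labels": [1]}])
print(message)
```

The printed message contains `takes 2 positional arguments but 3 were given`. The two accepted positional arguments are `self` and `x`, and the extra third one is `targets`.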

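Get segmentation model function:

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
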
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)
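
    model.train()
    # wrap the model between a QuantStub and a DeQuantStub
    quant_mod = MQuantise(model)
    # qconfig and quantized engine for the QNNPACK backend
    quant_mod.qconfig = torch.quantization.get_default_qconfig('qnnpack')
    torch.backends.quantized.engine = "qnnpack"
    # prepare_qat inserts fake-quantization modules for quantization-aware
    # training; the model is not converted to int8 at this point
    model_static_quantized = torch.quantization.prepare_qat(quant_mod, inplace=True)
    return model_static_quantized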
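The text above describes a post-training static method, but this function calls `prepare_qat`, which prepares the model for quantization-aware training and does not convert it. For comparison, here is a minimal sketch of the eager-mode post-training static flow (fuse, prepare, calibrate, convert). It runs on a hypothetical `TinyNet`, not on Mask R-CNN. The fallback to the `fbgemm` engine is an assumption for builds without QNNPACK, and the random calibration data is only a placeholder:

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy model with explicit quant/dequant stubs."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)             # float -> quantized (after convert)
        x = self.relu(self.conv(x))
        return self.dequant(x)        # quantized -> float (after convert)


# Use QNNPACK as in the post, falling back to FBGEMM if this build lacks it.
engine = "qnnpack" if "qnnpack" in torch.backends.quantized.supported_engines else "fbgemm"
torch.backends.quantized.engine = engine

model = TinyNet().eval()                       # post-training: eval mode
torch.quantization.fuse_modules(model, [["conv", "relu"]], inplace=True)
model.qconfig = torch.quantization.get_default_qconfig(engine)
prepared = torch.quantization.prepare(model)   # insert observers

# Calibrate the observers on representative data (random here, as a placeholder).
with torch.no_grad():
    for _ in range(8):
        prepared(torch.randn(1, 3, 16, 16))

quantized = torch.quantization.convert(prepared)  # swap in quantized modules
out = quantized(torch.randn(1, 3, 16, 16))
```

For quantization-aware training, the usual pairing is `get_default_qat_qconfig` with `prepare_qat`. You train in train mode, then call `convert` after switching to eval mode.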

Quantise model class:

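import torch

class MQuantise(torch.nn.Module):
    """Wraps a whole model between a QuantStub and a DeQuantStub."""
    def __init__(self, model):
        super(MQuantise, self).__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized after convert
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float after convert
        self.model = model

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x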
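`MQuantise.forward` accepts a single argument. However, the training loop calls the model as `model(images, targets)`, and torchvision detection models take a list of image tensors plus an optional list of target dicts. Below is a minimal sketch of a wrapper whose `forward` mirrors that `(images, targets=None)` signature. It is exercised against a hypothetical `DummyDetector` rather than the real Mask R-CNN. It only fixes the call signature and quantizes nothing. The stubs are left out on the assumption that they expect a single tensor rather than a list of images, which is worth checking against your PyTorch version:

```python
import torch
from torch import nn


class DummyDetector(nn.Module):
    """Hypothetical stand-in for a torchvision detection model: a list of
    images plus optional targets in, losses (train) or predictions (eval) out."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, images, targets=None):
        if self.training:
            if targets is None:
                raise ValueError("targets are required in training mode")
            loss = self.scale * sum(img.mean() for img in images)
            return {"loss_dummy": loss}
        return [{"boxes": torch.zeros(0, 4)} for _ in images]


class DetectorWrapper(nn.Module):
    """Wrapper whose forward matches the wrapped model's (images, targets=None)
    signature, so `model(images, targets)` in a training loop still works."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, images, targets=None):
        return self.model(images, targets)


wrapper = DetectorWrapper(DummyDetector())
images = [torch.rand(3, 32, 32), torch.rand(3, 24, 24)]
targets = [{"boxes": torch.zeros(0, 4)}, {"boxes": torch.zeros(0, 4)}]

wrapper.train()
loss_dict = wrapper(images, targets)   # the call pattern used in engine.py
wrapper.eval()
with torch.no_grad():
    predictions = wrapper(images)
```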

Quantise backbone class:

import torch

class MQuantise_backbone(torch.nn.Module):
    """Intended to quantise only the backbone and reuse the model's RPN and RoI heads."""
    def __init__(self, model):
        super(MQuantise_backbone, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.backbone = model.backbone
        self.rpn = model.rpn
        self.head = model.roi_heads

    def forward(self, x):
        x = self.quant(x)
        features_quant = self.backbone(x)
        features = self.dequant(features_quant)
        proposals = self.rpn(features)
        head_results = self.head(features, proposals)
        return head_results
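In `MQuantise_backbone.forward`, the FPN backbone returns a dict of feature maps (an `OrderedDict` keyed `'0'` to `'3'` and `'pool'` in the torchvision versions I am aware of), not a single tensor. As far as I know, after `torch.quantization.convert` the `DeQuantStub` becomes a module that calls `.dequantize()` on its input, which a dict does not have. Below is a minimal sketch of dequantizing each feature map separately. The hand-made quantized tensors, their shapes, and the scale/zero-point values are assumptions standing in for real backbone outputs. Separately, torchvision's RPN and RoI heads appear to take more arguments than the wrapper passes, roughly `rpn(images, features, targets)` and `roi_heads(features, proposals, image_sizes, targets)`, so check the signatures in your installed version.

```python
from collections import OrderedDict

import torch


def dequantize_features(features):
    """Dequantize every feature map in an FPN-style OrderedDict, keeping keys."""
    return OrderedDict((name, fmap.dequantize()) for name, fmap in features.items())


# Hypothetical quantized feature maps with FPN-like keys and shapes.
quant_features = OrderedDict(
    (name, torch.quantize_per_tensor(torch.rand(1, 256, size, size), 0.05, 0, torch.quint8))
    for name, size in [("0", 32), ("1", 16), ("2", 8), ("3", 4), ("pool", 2)]
)
float_features = dequantize_features(quant_features)
```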