Quantizing FasterRCNN PreTrained Model

Hi,

I am working on quantizing a FasterRCNN model from pre-trained weights, and I am running into a couple of issues with the FeaturePyramidNetwork layer. When I run my model, I get the following error, which I assume is because I am passing quantized images to a non-quantized layer.

File "/home/maria/anaconda3/envs/engie/lib/python3.8/site-packages/torchvision/ops/feature_pyramid_network.py", line 131, in forward
last_inner = inner_lateral + inner_top_down

RuntimeError: Could not run 'aten::add.Tensor' with arguments from the 'QuantizedCPU' backend. 'aten::add.Tensor' is only available for these backends: [CPU, CUDA, MkldnnCPU, SparseCPU, SparseCUDA, Meta, Named, Autograd, Profiler, Tracer].

Any thoughts on how I can fix this issue and run this quantized model successfully on quantized Tensor images?

Hi @Maria_Vazhaeparambil , we have a quick help page about this type of error: Quantization — PyTorch master documentation

In this case, it looks like a quantized tensor is being passed to a floating point kernel. You could fix it in a couple of ways:

  1. convert it to fp32 before passing it to the layer (by passing it through torch.quantization.DeQuantStub())
  2. or, if you are quantizing the network, use torch.nn.quantized.FloatFunctional for adds
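
Both options can be sketched on a toy module. The TinyAddBlock class below is a hypothetical stand-in for the FPN's `inner_lateral + inner_top_down` step, not torchvision's code, and the layer sizes and random calibration data are made up for illustration. The try/except import is only there because newer PyTorch releases expose the same classes under torch.ao.

```python
import torch
import torch.nn as nn

# Newer PyTorch releases expose these under torch.ao; older ones do not.
try:
    import torch.ao.quantization as tq
    import torch.ao.nn.quantized as nnq
except ImportError:
    import torch.quantization as tq
    import torch.nn.quantized as nnq


class TinyAddBlock(nn.Module):
    """Hypothetical stand-in for an FPN-style 'lateral + top_down' add."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()        # fp32 -> int8 at the input
        self.lateral = nn.Conv2d(3, 8, kernel_size=1)
        self.top_down = nn.Conv2d(3, 8, kernel_size=1)
        self.add = nnq.FloatFunctional()   # option 2: an add that can be quantized
        self.dequant = tq.DeQuantStub()    # option 1: int8 -> fp32 for float-only code

    def forward(self, x):
        x = self.quant(x)
        # A plain `a + b` at this point is the kind of call that fails with the
        # 'aten::add.Tensor' error once both operands are quantized tensors.
        out = self.add.add(self.lateral(x), self.top_down(x))
        return self.dequant(out)


# Pick a quantized engine that this build actually supports.
engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend

model = TinyAddBlock().eval()
model.qconfig = tq.get_default_qconfig(backend)
prepared = tq.prepare(model)               # insert observers (works on a copy)
with torch.no_grad():
    prepared(torch.randn(4, 3, 16, 16))    # calibrate with made-up data
quantized = tq.convert(prepared)           # swap in int8 modules (works on a copy)

with torch.no_grad():
    out = quantized(torch.randn(1, 3, 16, 16))
print(type(quantized.add).__name__, out.dtype, tuple(out.shape))
```

For a model you cannot edit directly, such as torchvision's FeaturePyramidNetwork, the same idea would mean either dequantizing the features before they enter the FPN or subclassing it to route the add through a FloatFunctional.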

Thank you so much for your reply. I was able to use your fix to get my code up and running, but I realized that statically quantizing my model this way means that all the computations will run on the CPU. I read online that Quantization Aware Training still allows me to train on the GPU, so I was hoping to use that instead.

I quantized my model using the following code:

class faster_countor_NN(Countor_NN):
    def __init__(self):
        super(faster_countor_NN, self).__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, images, boxes, boxes_ids, ROI_images):
        det_boxes, det_scores, det_labels, tck_boxes, tck_scores, tck_labels, tck_ids = super().forward(images, boxes, boxes_ids, ROI_images)
        return det_boxes, det_scores, det_labels, tck_boxes, tck_scores, tck_labels, tck_ids

class QuantizedRCNN(torch.nn.Module):
    def __init__(self, model_fp32):
        super(QuantizedRCNN, self).__init__()
        self.model_fp32 = model_fp32

    def forward(self, images, boxes, boxes_ids, ROI_images):
        det_boxes, det_scores, det_labels, tck_boxes, tck_scores, tck_labels, tck_ids = self.model_fp32(images, boxes, boxes_ids, ROI_images)
        return det_boxes, det_scores, det_labels, tck_boxes, tck_scores, tck_labels, tck_ids

init_net = faster_countor_NN()
init_net.load_state_dict(torch.load(paths['path_faster_RCNN_weigths'],map_location="cuda"), strict=False)
init_net.cpu()
init_net_fused = copy.deepcopy(init_net)
init_net.train()
init_net_fused.train()
init_net = torch.quantization.fuse_modules(init_net,[['conv', 'relu']])
init_net.eval()
init_net_fused.eval()
countor_net = QuantizedRCNN(model_fp32=init_net_fused)
quantization_config = torch.quantization.get_default_qconfig("fbgemm")
countor_net.qconfig = quantization_config
torch.quantization.prepare_qat(countor_net, inplace=True)
countor_net.train()
countor_net.to('cpu')
countor_net = torch.quantization.convert(countor_net, inplace=True)
countor_net.cuda()
countor_net.eval()

However, I am still getting the following error in my code.

    File "/home/maria/anaconda3/envs/engie/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 331, in forward
        return ops.quantized.conv2d(

RuntimeError: Could not run 'quantized::conv2d.new' with arguments from the 'CUDA' backend. 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/build/aten/src/ATen/CPUType.cpp:2127 [kernel]
BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370172916/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel] 
AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
Tracer: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/autograd/generated/TraceType_2.cpp:9654 [kernel]
Autocast: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370172916/work/aten/src/ATen/autocast_mode.cpp:254 [backend fallback]
Batched: registered at /opt/conda/conda-bld/pytorch_1607370172916/work/aten/src/ATen/BatchingRegistrations.cpp:511 [backend fallback]
VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370172916/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

Even when I try to work around this by sending my images back to the CPU (which I don’t want to do), I get a segmentation fault because my model is on the GPU. Is there any way to run my model on the GPU successfully? Is this possible, or am I misunderstanding Quantization Aware Training?

Hi @Maria_Vazhaeparambil , the part of your snippet that moves the converted model to the GPU (countor_net.cuda() after torch.quantization.convert) is what is not supported. When you call torch.quantization.convert, the fp32 kernels get swapped for int8 kernels. There is currently no support for running int8 kernels on the GPU. If you’d like to evaluate a model with int8 kernels, it has to be done on the CPU, so you would need to move your converted model to the CPU.

I read this on the PyTorch website (Introduction to Quantization on PyTorch | PyTorch):

“However, quantization aware training occurs in full floating point and can run on either GPU or CPU. Quantization aware training is typically only used in CNN models when post training static or dynamic quantization doesn’t yield sufficient accuracy.”

Would this not be possible?

This means that QAT can run training on the GPU. Inference on a converted model is only supported on CPU.
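
A minimal sketch of that split, assuming a toy model (TinyNet, the dummy loss, and the tensor sizes are made up for illustration): the fake-quantized model trains on the GPU when one is available, then is moved to the CPU before convert, and the converted model stays on the CPU for inference.

```python
import torch
import torch.nn as nn

# Newer PyTorch releases expose this under torch.ao; older ones do not.
try:
    import torch.ao.quantization as tq
except ImportError:
    import torch.quantization as tq


class TinyNet(nn.Module):
    """Hypothetical toy model, used only to show the order of the steps."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))


engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend
train_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Prepare for QAT. Note the QAT qconfig and train mode.
model = TinyNet().train()
model.qconfig = tq.get_default_qat_qconfig(backend)
tq.prepare_qat(model, inplace=True)

# 2. Train in floating point with fake quantization; this may run on the GPU.
model.to(train_device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for _ in range(3):
    x = torch.randn(2, 3, 8, 8, device=train_device)
    loss = model(x).pow(2).mean()          # dummy loss on made-up data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 3. Move to the CPU first, then convert. Do not call .cuda() afterwards.
model.to("cpu").eval()
int8_model = tq.convert(model)

# 4. Inference with int8 kernels runs on the CPU, with CPU inputs.
with torch.no_grad():
    out = int8_model(torch.randn(1, 3, 8, 8))
print(type(int8_model.conv).__module__, out.dtype)
```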

Hi @Vasiliy_Kuznetsov. I too want to quantize a pretrained FasterRCNN model with a MobileNetV3 backbone, but I don’t know where to start. Do I have to make a copy of the original source code here and modify it? Let’s say I modified the MobileNetV3 backbone and added the necessary “quant” lines; what changes do I need to make in the object detector part (RoIHeads, RPN, etc.)? Thank you.

Could you follow Quantization — PyTorch master documentation?

Thank you @jerryzh168. So, I tried to quantize the mobilenetv3 backbone only. I modified the mobilenet_backbone() function to obtain the quantized version of the backbone as follows:

#backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
backbone = torchvision.models.quantization.mobilenet_v3_large(pretrained=False, quantize=False).features

Then, I followed this tutorial to create and train a FasterRCNN model. Before training, I fused the model backbone as follows:

# Create quantized model
fused_model = get_quantized_object_detection_model(num_classes, pretrained=True)
fused_model.to(cpu_device)
fused_model.train()
# Fuse layers
for m in fused_model.backbone.modules():
    if type(m) == ConvBNActivation:
        modules_to_fuse = ['0', '1']
        if type(m[2]) == nn.ReLU:
            modules_to_fuse.append('2')
        fuse_modules(m, modules_to_fuse, inplace=True)
    elif type(m) == QuantizableSqueezeExcitation:
        fuse_modules(m, ['fc1', 'relu'], inplace=True)
    elif type(m) == QuantizableInvertedResidual:
        for idx in range(len(m.block)):
            if type(m.block[idx]) == nn.Conv2d:
                fuse_modules(m.block, [str(idx), str(idx + 1)], inplace=True)

Then I prepared the model for quantization aware training:

backend = 'fbgemm'
fused_model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(fused_model.backbone, inplace=True)
fused_model.to(cuda_device)

    # ... creating optimizer and training

fused_model.to(cpu_device)
torch.quantization.convert(fused_model.backbone, inplace=True)
fused_model.eval()

    # ... save model

I did not get any errors during preparation, but after quantization the model size did not decrease (it is still 74MB), and both inference speed and accuracy got worse. Obviously something is not right. Any thoughts?

Could you print the final quantized model? Does it contain any quantized modules?
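
If the printout is too long to read, the same check can be done programmatically. This is a minimal sketch using only stock PyTorch; the helper names quantized_module_names and state_dict_size_bytes are made up for this example, and the toy model is dynamically quantized purely to have something to inspect.

```python
import io

import torch
import torch.nn as nn

# Newer PyTorch releases expose these under torch.ao; older ones do not.
try:
    import torch.ao.quantization as tq
    import torch.ao.nn.quantized as nnq
except ImportError:
    import torch.quantization as tq
    import torch.nn.quantized as nnq


def quantized_module_names(model):
    """Return the names of submodules that are int8 conv or linear modules."""
    # Fused (e.g. ConvReLU2d) and dynamic variants subclass these base types.
    quantized_types = (nnq.Conv2d, nnq.Linear)
    return [name for name, m in model.named_modules()
            if isinstance(m, quantized_types)]


def state_dict_size_bytes(model):
    """Serialize the state dict in memory and return its size in bytes."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes


float_model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10)).eval()
int8_model = tq.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

print("float model:", quantized_module_names(float_model),
      state_dict_size_bytes(float_model), "bytes")
print("int8 model: ", quantized_module_names(int8_model),
      state_dict_size_bytes(int8_model), "bytes")
```

An empty list for the model you converted, together with an unchanged size, would suggest that convert did not swap anything.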

Here is the output of print(fused_model):

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(320,), max_size=640, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (0): ConvBNActivation(
        (0): ConvBn2d(
          (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
        (1): Identity()
        (2): Hardswish()
      )
      (1): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
              (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (1): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (2): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (1): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
              (1): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (2): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(64, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (3): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (1): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72, bias=False)
              (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (2): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (4): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (1): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(72, 72, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=72, bias=False)
              (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (2): QuantizableSqueezeExcitation(
            (fc1): ConvReLU2d(
              (0): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1))
              (1): ReLU()
            )
            (relu): Identity()
            (fc2): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1))
            (skip_mul): FloatFunctional(
              (activation_post_process): Identity()
            )
          )
          (3): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(72, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (5): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (1): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
              (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (2): QuantizableSqueezeExcitation(
            (fc1): ConvReLU2d(
              (0): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
              (1): ReLU()
            )
            (relu): Identity()
            (fc2): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
            (skip_mul): FloatFunctional(
              (activation_post_process): Identity()
            )
          )
          (3): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (6): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (1): ConvBNActivation(
            (0): ConvBnReLU2d(
              (0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
              (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU()
            )
            (1): Identity()
            (2): Identity()
          )
          (2): QuantizableSqueezeExcitation(
            (fc1): ConvReLU2d(
              (0): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
              (1): ReLU()
            )
            (relu): Identity()
            (fc2): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
            (skip_mul): FloatFunctional(
              (activation_post_process): Identity()
            )
          )
          (3): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (7): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (1): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=240, bias=False)
              (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (2): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (8): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(80, 200, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(200, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (1): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(200, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=200, bias=False)
              (1): BatchNorm2d(200, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (2): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(200, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (9): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(184, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (1): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False)
              (1): BatchNorm2d(184, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (2): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (10): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(184, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (1): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False)
              (1): BatchNorm2d(184, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (2): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(80, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (11): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(480, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (1): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=480, bias=False)
              (1): BatchNorm2d(480, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (2): QuantizableSqueezeExcitation(
            (fc1): ConvReLU2d(
              (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))
              (1): ReLU()
            )
            (relu): Identity()
            (fc2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))
            (skip_mul): FloatFunctional(
              (activation_post_process): Identity()
            )
          )
          (3): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(112, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (12): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(672, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (1): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(672, 672, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=672, bias=False)
              (1): BatchNorm2d(672, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (2): QuantizableSqueezeExcitation(
            (fc1): ConvReLU2d(
              (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))
              (1): ReLU()
            )
            (relu): Identity()
            (fc2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))
            (skip_mul): FloatFunctional(
              (activation_post_process): Identity()
            )
          )
          (3): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(112, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (13): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(672, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (1): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=672, bias=False)
              (1): BatchNorm2d(672, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (2): QuantizableSqueezeExcitation(
            (fc1): ConvReLU2d(
              (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))
              (1): ReLU()
            )
            (relu): Identity()
            (fc2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))
            (skip_mul): FloatFunctional(
              (activation_post_process): Identity()
            )
          )
          (3): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(672, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(160, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (14): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (1): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)
              (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (2): QuantizableSqueezeExcitation(
            (fc1): ConvReLU2d(
              (0): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))
              (1): ReLU()
            )
            (relu): Identity()
            (fc2): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))
            (skip_mul): FloatFunctional(
              (activation_post_process): Identity()
            )
          )
          (3): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(160, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (15): QuantizableInvertedResidual(
        (block): Sequential(
          (0): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (1): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)
              (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Hardswish()
          )
          (2): QuantizableSqueezeExcitation(
            (fc1): ConvReLU2d(
              (0): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))
              (1): ReLU()
            )
            (relu): Identity()
            (fc2): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))
            (skip_mul): FloatFunctional(
              (activation_post_process): Identity()
            )
          )
          (3): ConvBNActivation(
            (0): ConvBn2d(
              (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(160, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
            (1): Identity()
            (2): Identity()
          )
        )
        (skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (16): ConvBNActivation(
        (0): ConvBn2d(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
        (1): Identity()
        (2): Hardswish()
      )
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2d(160, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(960, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (layer_blocks): ModuleList(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (extra_blocks): LastLevelMaxPool()
    )
  )
  (rpn): RegionProposalNetwork(
    (anchor_generator): AnchorGenerator()
    (head): RPNHead(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (cls_logits): Conv2d(256, 15, kernel_size=(1, 1), stride=(1, 1))
      (bbox_pred): Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (roi_heads): RoIHeads(
    (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
    (box_head): TwoMLPHead(
      (fc6): Linear(in_features=12544, out_features=1024, bias=True)
      (fc7): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (box_predictor): FastRCNNPredictor(
      (cls_score): Linear(in_features=1024, out_features=2, bias=True)
      (bbox_pred): Linear(in_features=1024, out_features=8, bias=True)
    )
  )
)

It looks like there are no quantized modules in the backbone; maybe something went wrong in the previous steps.
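
One possible cause, which this thread does not confirm: in the snippet above, qconfig is assigned to fused_model, while prepare_qat and convert are called on fused_model.backbone. As far as I know, eager-mode quantization only looks at the qconfig of the module it is given and of that module's children, so a qconfig set on the parent is never seen. The printout is consistent with that (plain ConvBn2d modules, and Identity where observers would normally sit). A minimal sketch with hypothetical stand-in classes Backbone and Detector:

```python
import torch
import torch.nn as nn

# Newer PyTorch releases expose this under torch.ao; older ones do not.
try:
    import torch.ao.quantization as tq
except ImportError:
    import torch.quantization as tq


class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


class Detector(nn.Module):
    """Hypothetical stand-in for a detector that owns a backbone."""

    def __init__(self):
        super().__init__()
        self.backbone = Backbone()

    def forward(self, x):
        return self.backbone(x)


def count_fake_quant(module):
    """Count fake-quantize submodules, matched by class name."""
    return sum("FakeQuantize" in type(m).__name__ for m in module.modules())


engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
qat_qconfig = tq.get_default_qat_qconfig(backend)

# As in the thread: qconfig on the parent, prepare_qat on the child.
on_parent = Detector().train()
on_parent.qconfig = qat_qconfig
tq.prepare_qat(on_parent.backbone, inplace=True)

# Variant: qconfig on the module that is actually passed to prepare_qat.
on_child = Detector().train()
on_child.backbone.qconfig = qat_qconfig
tq.prepare_qat(on_child.backbone, inplace=True)

print("qconfig on parent:", count_fake_quant(on_parent), "fake-quant modules")
print("qconfig on child: ", count_fake_quant(on_child), "fake-quant modules")
```

Even with the qconfig in the right place, I believe a backbone quantized on its own would also need a QuantStub before it and a DeQuantStub after it, so that the float FPN still receives fp32 tensors; taking only .features from the quantizable MobileNetV3 leaves that model's own stubs behind.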