Slow Inference time on Quantized Faster RCNN model

I applied QuantWrapper() to a pre-trained FasterRCNN model with a MobileNet V3-320 backbone. Although the model size has dropped to about 25% of the original, the inference time is roughly unchanged.
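For reference, the flow I used is roughly the standard eager-mode post-training static quantization recipe (a sketch, not my exact code; `calibration_loader` is a placeholder):

```python
import torch
import torchvision

# Pre-trained Faster R-CNN with MobileNet V3-320 backbone
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True)
model.eval()

# Wrap the model so its input is quantized and its output dequantized
quantized_model = torch.quantization.QuantWrapper(model)
quantized_model.qconfig = torch.quantization.get_default_qconfig("fbgemm")

torch.quantization.prepare(quantized_model, inplace=True)
with torch.no_grad():
    for images, _ in calibration_loader:  # placeholder calibration data
        quantized_model(images)
torch.quantization.convert(quantized_model, inplace=True)
```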

QuantWrapper(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (module): QuantWrapper(
    (quant): QuantStub()
    (dequant): DeQuantStub()
    (module): FasterRCNN(
      (transform): GeneralizedRCNNTransform(
          Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
          Resize(min_size=(320,), max_size=640, mode='bilinear')
      )
      (backbone): BackboneWithFPN(
        (body): IntermediateLayerGetter(
          (0): ConvBNActivation(
            (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(16, eps=1e-05)
            (2): Hardswish()
          )
          (1): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
                (1): FrozenBatchNorm2d(16, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (1): ConvBNActivation(
                (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(16, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (2): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(64, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (1): ConvBNActivation(
                (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
                (1): FrozenBatchNorm2d(64, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (2): ConvBNActivation(
                (0): Conv2d(64, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(24, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (3): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(72, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (1): ConvBNActivation(
                (0): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72, bias=False)
                (1): FrozenBatchNorm2d(72, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (2): ConvBNActivation(
                (0): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(24, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (4): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(72, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (1): ConvBNActivation(
                (0): Conv2d(72, 72, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=72, bias=False)
                (1): FrozenBatchNorm2d(72, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (2): SqueezeExcitation(
                (fc1): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1))
                (relu): ReLU(inplace=True)
                (fc2): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1))
              )
              (3): ConvBNActivation(
                (0): Conv2d(72, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(40, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (5): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(120, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (1): ConvBNActivation(
                (0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
                (1): FrozenBatchNorm2d(120, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (2): SqueezeExcitation(
                (fc1): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
                (relu): ReLU(inplace=True)
                (fc2): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
              )
              (3): ConvBNActivation(
                (0): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(40, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (6): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(120, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (1): ConvBNActivation(
                (0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
                (1): FrozenBatchNorm2d(120, eps=1e-05)
                (2): ReLU(inplace=True)
              )
              (2): SqueezeExcitation(
                (fc1): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
                (relu): ReLU(inplace=True)
                (fc2): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
              )
              (3): ConvBNActivation(
                (0): Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(40, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (7): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(240, eps=1e-05)
                (2): Hardswish()
              )
              (1): ConvBNActivation(
                (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=240, bias=False)
                (1): FrozenBatchNorm2d(240, eps=1e-05)
                (2): Hardswish()
              )
              (2): ConvBNActivation(
                (0): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(80, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (8): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(80, 200, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(200, eps=1e-05)
                (2): Hardswish()
              )
              (1): ConvBNActivation(
                (0): Conv2d(200, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=200, bias=False)
                (1): FrozenBatchNorm2d(200, eps=1e-05)
                (2): Hardswish()
              )
              (2): ConvBNActivation(
                (0): Conv2d(200, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(80, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (9): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(184, eps=1e-05)
                (2): Hardswish()
              )
              (1): ConvBNActivation(
                (0): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False)
                (1): FrozenBatchNorm2d(184, eps=1e-05)
                (2): Hardswish()
              )
              (2): ConvBNActivation(
                (0): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(80, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (10): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(80, 184, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(184, eps=1e-05)
                (2): Hardswish()
              )
              (1): ConvBNActivation(
                (0): Conv2d(184, 184, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=184, bias=False)
                (1): FrozenBatchNorm2d(184, eps=1e-05)
                (2): Hardswish()
              )
              (2): ConvBNActivation(
                (0): Conv2d(184, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(80, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (11): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(480, eps=1e-05)
                (2): Hardswish()
              )
              (1): ConvBNActivation(
                (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=480, bias=False)
                (1): FrozenBatchNorm2d(480, eps=1e-05)
                (2): Hardswish()
              )
              (2): SqueezeExcitation(
                (fc1): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))
                (relu): ReLU(inplace=True)
                (fc2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))
              )
              (3): ConvBNActivation(
                (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(112, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (12): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(672, eps=1e-05)
                (2): Hardswish()
              )
              (1): ConvBNActivation(
                (0): Conv2d(672, 672, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=672, bias=False)
                (1): FrozenBatchNorm2d(672, eps=1e-05)
                (2): Hardswish()
              )
              (2): SqueezeExcitation(
                (fc1): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))
                (relu): ReLU(inplace=True)
                (fc2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))
              )
              (3): ConvBNActivation(
                (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(112, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (13): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(672, eps=1e-05)
                (2): Hardswish()
              )
              (1): ConvBNActivation(
                (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=672, bias=False)
                (1): FrozenBatchNorm2d(672, eps=1e-05)
                (2): Hardswish()
              )
              (2): SqueezeExcitation(
                (fc1): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))
                (relu): ReLU(inplace=True)
                (fc2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))
              )
              (3): ConvBNActivation(
                (0): Conv2d(672, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(160, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (14): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(960, eps=1e-05)
                (2): Hardswish()
              )
              (1): ConvBNActivation(
                (0): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)
                (1): FrozenBatchNorm2d(960, eps=1e-05)
                (2): Hardswish()
              )
              (2): SqueezeExcitation(
                (fc1): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))
                (relu): ReLU(inplace=True)
                (fc2): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))
              )
              (3): ConvBNActivation(
                (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(160, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (15): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(960, eps=1e-05)
                (2): Hardswish()
              )
              (1): ConvBNActivation(
                (0): Conv2d(960, 960, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=960, bias=False)
                (1): FrozenBatchNorm2d(960, eps=1e-05)
                (2): Hardswish()
              )
              (2): SqueezeExcitation(
                (fc1): Conv2d(960, 240, kernel_size=(1, 1), stride=(1, 1))
                (relu): ReLU(inplace=True)
                (fc2): Conv2d(240, 960, kernel_size=(1, 1), stride=(1, 1))
              )
              (3): ConvBNActivation(
                (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): FrozenBatchNorm2d(160, eps=1e-05)
                (2): Identity()
              )
            )
          )
          (16): ConvBNActivation(
            (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(960, eps=1e-05)
            (2): Hardswish()
          )
        )
        (fpn): FeaturePyramidNetwork(
          (inner_blocks): ModuleList(
            (0): Conv2d(160, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(960, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_blocks): ModuleList(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (extra_blocks): LastLevelMaxPool()
        )
      )
      (rpn): RegionProposalNetwork(
        (anchor_generator): AnchorGenerator()
        (head): RPNHead(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (cls_logits): Conv2d(256, 15, kernel_size=(1, 1), stride=(1, 1))
          (bbox_pred): Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (roi_heads): RoIHeads(
        (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
        (box_head): TwoMLPHead(
          (fc6): Linear(in_features=12544, out_features=1024, bias=True)
          (fc7): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (box_predictor): FastRCNNPredictor(
          (cls_score): Linear(in_features=1024, out_features=91, bias=True)
          (bbox_pred): Linear(in_features=1024, out_features=364, bias=True)
        )
      )
    )
  )
)

Any suggestions?

One good next step would be to run per-op profiling (see PyTorch Profiler — PyTorch Tutorials 1.8.0 documentation) on both the fp32 and quantized versions of your model, to see which kernels contribute the most to inference time.
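A minimal sketch of what that could look like (assuming `fp32_model` and `quantized_model` are your two variants and `images` is the usual input of float image tensors):

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

def run_profile(model, images):
    model.eval()
    with torch.no_grad():
        with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
            with record_function("model_inference"):
                model(images)
    # Per-op breakdown, sorted by total CPU time
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))

run_profile(fp32_model, images)       # baseline
run_profile(quantized_model, images)  # quantized
```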

Thank you for the reply.
However, I’m getting an error on the quantized model:

Could not run 'aten::quantize_per_tensor' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit Internal Login for possible resolutions. 'aten::quantize_per_tensor' is only available for these backends: [CPU, CUDA, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

This means you are quantizing an already quantized Tensor. It happens when QuantStub is not placed correctly in the original model, so a Tensor ends up being quantized multiple times.

e.g.
x = self.quant(x)
x = self.quant2(x)

It looks like you have a QuantWrapper over both the original model and the FasterRCNN model, which means the input of the FasterRCNN model is quantized twice.
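In other words, wrapping once at the top level should be enough, e.g. (a sketch; `model` here stands for the plain FasterRCNN instance):

```python
import torch

# One QuantWrapper -> one QuantStub/DeQuantStub pair, so the input is quantized only once
wrapped = torch.quantization.QuantWrapper(model)

# Avoid nesting wrappers like this, which quantizes the input twice:
# wrapped = torch.quantization.QuantWrapper(torch.quantization.QuantWrapper(model))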

Do you think it’s because of the

(quant): QuantStub()
(dequant): DeQuantStub()
(module): QuantWrapper(
(quant): QuantStub()
(dequant): DeQuantStub()

Portion?
Anyway, I have fixed that now, and I found this error while using the Profiler:

> Expected input images to be of floating type (in range [0, 1]), but found type torch.quint8 instead

@jerryzh168 @Vasiliy_Kuznetsov
Irrespective of the above error, I still don’t understand why the inference time is almost the same even though the relevant modules show up as quantized when the model is printed:

#Normal model
model_inference 5.31% 24.628ms 99.99% 463.702ms 463.702ms

#Static quantized model
model_inference 4.93% 22.530ms 99.99% 456.504ms 456.504ms

#ROI quantized model (dynamically quantizing the Linear layers)
model_inference 5.31% 24.628ms 99.99% 461.05ms 461.05ms

Which hardware (server CPU / mobile CPU) and quantized engine (fbgemm / qnnpack) are you running this with?
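You can check (and switch) the active engine with something like this (a minimal sketch):

```python
import torch

print(torch.backends.quantized.engine)             # e.g. 'fbgemm' on x86 server CPUs
print(torch.backends.quantized.supported_engines)  # e.g. ['none', 'fbgemm', 'qnnpack']

# Switch explicitly if needed, e.g. for ARM/mobile CPUs:
# torch.backends.quantized.engine = 'qnnpack'
```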