Static quantization for YOLOv5 model

Hello,

I am trying to statically quantize the YOLOv5 model (repo: https://github.com/ultralytics/yolov5). I am loading the model into an nn.Module container class in order to apply the quantization and dequantization stubs. The code looks like this:

import torch
import torch.nn as nn

class QuantizationModule(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.quant = torch.quantization.QuantStub()      # fp32 -> int8 at the input
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> fp32 at the output

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x


model = QuantizationModule(model)  # `model` here is loaded by the repo's export.py
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.backends.quantized.engine = "qnnpack"
model_static_quantized = torch.quantization.prepare(model, inplace=False)
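
# Calibration pass: feed representative data through the prepared model so
# the observers can record activation ranges before convert.
# (`calib_loader` is a placeholder for an iterable of representative images.)
with torch.no_grad():
    for calib_img in calib_loader:
        model_static_quantized(calib_img)
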
model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)

img = torch.zeros(1, 3, 640, 640).to(device)
ts = torch.jit.trace(model_static_quantized, img)

However, there seems to be an issue:

TorchScript export failure: Could not run 'aten::mul.Tensor' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::mul.Tensor' is only available for these backends: [CPU, CUDA, MkldnnCPU, SparseCPU, SparseCUDA, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /pytorch/build/aten/src/ATen/RegisterCPU.cpp:5925 [kernel]
CUDA: registered at /pytorch/build/aten/src/ATen/RegisterCUDA.cpp:7100 [kernel]
MkldnnCPU: registered at /pytorch/build/aten/src/ATen/RegisterMkldnnCPU.cpp:284 [kernel]
SparseCPU: registered at /pytorch/build/aten/src/ATen/RegisterSparseCPU.cpp:557 [kernel]
SparseCUDA: registered at /pytorch/build/aten/src/ATen/RegisterSparseCUDA.cpp:655 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: fallthrough registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
AutogradOther: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:8707 [autograd kernel]
AutogradCPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:8707 [autograd kernel]
AutogradCUDA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:8707 [autograd kernel]
AutogradXLA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:8707 [autograd kernel]
AutogradNestedTensor: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:8707 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:8707 [autograd kernel]
AutogradPrivateUse1: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:8707 [autograd kernel]
AutogradPrivateUse2: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:8707 [autograd kernel]
AutogradPrivateUse3: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:8707 [autograd kernel]
Tracer: registered at /pytorch/torch/csrc/autograd/generated/TraceType_4.cpp:10612 [kernel]
Autocast: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:250 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/BatchingRegistrations.cpp:1020 [kernel]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

Do you have any idea why this might be? Is it an unsupported operation in the backend, or is there something wrong with my code? I have to mention that forwarding the image through the quantized module does give me the right results. Any help is much appreciated, thanks!

I have to mention that forwarding the image through the quantized module does give me the right results.

To clarify: does this mean that you are able to run forward on your quantized model, but you just cannot trace it? Knowing the answer to this will help us debug further.

Hey Vasiliy, yes, that is the issue.

Hi @vladsb94, I checked out yolov5 and applied the code you provided (gist: d863f53c8809198b3e0a4fd2af1563a7). I’m not sure how you got things to work without additional changes. This model uses the SiLU activation, which does not have an int8 kernel. To make this model quantizable to int8, there are a couple of options:

  1. add an int8 kernel for SiLU (we would happily accept a PR)
  2. add a custom quantizable module for SiLU that just does dequant → SiLU → quant (a minimal sketch is below)
  3. make yolov5 symbolically traceable and use FX graph mode quantization
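
For illustration, option 2 could be a small wrapper module along these lines (an untested sketch; QuantizableSiLU is just an illustrative name, and it needs to be swapped in before prepare/convert so the stubs pick up observers):

import torch
import torch.nn as nn

class QuantizableSiLU(nn.Module):
    # Run SiLU in fp32 inside an otherwise int8 model:
    # dequantize -> SiLU -> requantize.
    def __init__(self):
        super().__init__()
        self.dequant = torch.quantization.DeQuantStub()
        self.silu = nn.SiLU()
        self.quant = torch.quantization.QuantStub()

    def forward(self, x):
        x = self.dequant(x)  # int8 -> fp32
        x = self.silu(x)     # fp32 SiLU (no int8 kernel available)
        x = self.quant(x)    # fp32 -> int8, via its own observer
        return x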

Hi @Vasiliy_Kuznetsov, thanks for taking the time to answer and debug the model.

I have tried option 2; however, there are other operations within the model that also require dequant/quant handling, and I end up modifying too much of the code.

I also tried option 3, but I still have some issues. Here is a snippet of the code used for export with FX graph mode quantization; I put it in export.py right after line 61:

import copy
from torch.quantization.quantize_fx import prepare_fx, convert_fx

model_to_quantize = copy.deepcopy(model)
model_to_quantize.eval()
qconfig = torch.quantization.get_default_qconfig('qnnpack')
qconfig_dict = {"": qconfig}

prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
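
# Calibration pass: feed representative data through the prepared model so
# the observers record real activation ranges before convert_fx.
# (`calib_loader` is a placeholder for an iterable of representative images.)
with torch.no_grad():
    for calib_img in calib_loader:
        prepared_model(calib_img)
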
quantized_model = convert_fx(prepared_model)
yy = quantized_model(img)

But then, when it’s traced with ts = torch.jit.trace(quantized_model, img), I get a very strange output which I find difficult to interpret. Also, the predictions from the quantized model are not correct: I get lots of small bounding boxes. It looks something like this (I won’t copy all of it here because it’s very large):

<eval_with_key_5>:7: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  quantize_per_tensor_1 = torch.quantize_per_tensor(getitem, model_0_input_scale_0, model_0_input_zero_point_0, model_0_input_dtype_0);  getitem = model_0_input_scale_0 = model_0_input_zero_point_0 = model_0_input_dtype_0 = None
<eval_with_key_5>:12: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  quantize_per_tensor_2 = torch.quantize_per_tensor(getitem_1, model_0_input_scale_1, model_0_input_zero_point_1, model_0_input_dtype_1);  getitem_1 = model_0_input_scale_1 = model_0_input_zero_point_1 = model_0_input_dtype_1 = None
...

Maybe you can help me with the source of these warnings or at least point me in the right direction. Thanks for your time!

Hi @Vasiliy_Kuznetsov, hope you are doing well.
I am facing the same issue: there are some layers that are not being quantized, such as SiLU and BatchNorm1d.
Here you can see:

QuantizationModule(
  (model): Sequential(
    (0): Sequential(
      (0): QuantizedConv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0)
      (1): QuantizedBatchNorm2d(40, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
      (3): Sequential(
        (0): Sequential(
          (0): DepthwiseSeparableConv(
            (conv_dw): QuantizedConv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=40)
            (bn1): QuantizedBatchNorm2d(40, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(40, 10, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(10, 40, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pw): QuantizedConv2d(40, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn2): QuantizedBatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): Identity()
          )
          (1): DepthwiseSeparableConv(
            (conv_dw): QuantizedConv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=24)
            (bn1): QuantizedBatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(24, 6, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(6, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pw): QuantizedConv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn2): QuantizedBatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): Identity()
          )
        )
        (1): Sequential(
          (0): InvertedResidual(
            (conv_pw): QuantizedConv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(144, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): Conv2dSame(144, 144, kernel_size=(3, 3), stride=(2, 2), groups=144, bias=False)
            (bn2): QuantizedBatchNorm2d(144, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(144, 6, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(6, 144, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): InvertedResidual(
            (conv_pw): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=192)
            (bn2): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(192, 8, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(8, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (2): InvertedResidual(
            (conv_pw): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=192)
            (bn2): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(192, 8, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(8, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (2): Sequential(
          (0): InvertedResidual(
            (conv_pw): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): Conv2dSame(192, 192, kernel_size=(5, 5), stride=(2, 2), groups=192, bias=False)
            (bn2): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(192, 8, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(8, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): InvertedResidual(
            (conv_pw): QuantizedConv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(288, 288, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=288)
            (bn2): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(288, 12, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(12, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (2): InvertedResidual(
            (conv_pw): QuantizedConv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(288, 288, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=288)
            (bn2): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(288, 12, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(12, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (3): Sequential(
          (0): InvertedResidual(
            (conv_pw): QuantizedConv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): Conv2dSame(288, 288, kernel_size=(3, 3), stride=(2, 2), groups=288, bias=False)
            (bn2): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(288, 12, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(12, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): InvertedResidual(
            (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
            (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (2): InvertedResidual(
            (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
            (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (3): InvertedResidual(
            (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
            (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (4): InvertedResidual(
            (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
            (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (4): Sequential(
          (0): InvertedResidual(
            (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=576)
            (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(576, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): InvertedResidual(
            (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
            (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (2): InvertedResidual(
            (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
            (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (3): InvertedResidual(
            (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
            (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (4): InvertedResidual(
            (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
            (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (5): Sequential(
          (0): InvertedResidual(
            (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): Conv2dSame(816, 816, kernel_size=(5, 5), stride=(2, 2), groups=816, bias=False)
            (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(816, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): InvertedResidual(
            (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
            (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (2): InvertedResidual(
            (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
            (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (3): InvertedResidual(
            (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
            (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (4): InvertedResidual(
            (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
            (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (5): InvertedResidual(
            (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
            (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (6): Sequential(
          (0): InvertedResidual(
            (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=1392)
            (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(1392, 384, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): InvertedResidual(
            (conv_pw): QuantizedConv2d(384, 2304, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn1): QuantizedBatchNorm2d(2304, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (conv_dw): QuantizedConv2d(2304, 2304, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=2304)
            (bn2): QuantizedBatchNorm2d(2304, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act2): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): QuantizedConv2d(2304, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
              (act1): SiLU(inplace=True)
              (conv_expand): QuantizedConv2d(96, 2304, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            )
            (conv_pwl): QuantizedConv2d(2304, 384, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (bn3): QuantizedBatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
      )
      (4): QuantizedConv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
      (5): QuantizedBatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (6): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): AdaptiveConcatPool2d(
        (ap): AdaptiveAvgPool2d(output_size=1)
        (mp): AdaptiveMaxPool2d(output_size=1)
      )
      (1): Flatten(full=False)
      (2): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
      (4): QuantizedLinear(in_features=3072, out_features=512, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
      (5): ReLU(inplace=True)
      (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): Dropout(p=0.5, inplace=False)
      (8): QuantizedLinear(in_features=512, out_features=73, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
    )
  )
  (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)

Now I can filter out the SiLU layers as follows:

for name, layer in model_static_quantized2.named_modules():
    if isinstance(layer, nn.SiLU):
        print(name, layer)

But now, how can I specifically dequant and quant around these layers in the model?
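
One direction I am considering is swapping each SiLU for a dequant → SiLU → quant wrapper before prepare/convert, roughly like this (an untested sketch, reusing the QuantizableSiLU wrapper idea from earlier in the thread; `float_model` stands in for my model before quantization):

def swap_silu(module):
    # Recursively replace every nn.SiLU with the wrapper. This must run on
    # the float model, before prepare/convert, so the stubs get observers.
    for name, child in module.named_children():
        if isinstance(child, nn.SiLU):
            setattr(module, name, QuantizableSiLU())
        else:
            swap_silu(child)

swap_silu(float_model)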

Thanks for any help…


Hi @Muhammad_Ali, hope you are doing well.
Did you find a solution for this?
Thanks in advance.