I quantized the input tensor and all layers except the last, but got RuntimeError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend

Happy new year!

I quantized my UNet except for the last layer as follows, because I need full precision at the last layer.

import copy

import torch
import torch.nn as nn

class QuantizedUNet(nn.Module):
    def __init__(self, model_fp32):
        super(QuantizedUNet, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.until_last = copy.deepcopy(model_fp32)

        # Remove last layer from fp32 model and keep it in another variable
        del self.until_last.conv2[2]
        self.last_conv = model_fp32.conv2[2]

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.until_last(x)
        x = self.dequant(x)
        x = self.last_conv(x)
        return x
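
For context, the conversion followed the usual eager-mode flow, roughly like this (a sketch: `unet_fp32` and `calibration_loader` stand in for my actual model and calibration data, and the qconfig details are approximate):

model = QuantizedUNet(unet_fp32)
model.eval()

# BatchNorm+ReLU pairs were fused beforehand with
# torch.quantization.fuse_modules (not shown here).
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Insert observers, run calibration data through, then convert.
torch.quantization.prepare(model, inplace=True)
with torch.no_grad():
    for batch in calibration_loader:
        model(batch)
torch.quantization.convert(model, inplace=True)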

After static quantization and calibration, this is the model that I got.

QuantizedUNet(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (until_last): Unet(
    (down_sample_layers): ModuleList(
      (0): Sequential(
        (0): QuantizedConv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), scale=0.09462987631559372, zero_point=64, padding=(1, 1))
        (1): QuantizedBNReLU2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Identity()
        (3): QuantizedConv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), scale=0.6255205273628235, zero_point=83, padding=(1, 1))
        (4): QuantizedBNReLU2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Identity()
      )
      (1): Sequential(
        (0): QuantizedConv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), scale=1.403043270111084, zero_point=87, padding=(1, 1))
        (1): QuantizedBNReLU2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Identity()
        (3): QuantizedConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), scale=2.315826654434204, zero_point=60, padding=(1, 1))
        (4): QuantizedBNReLU2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Identity()
      )
      (2): Sequential(
        (0): QuantizedConv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), scale=5.481112957000732, zero_point=56, padding=(1, 1))
        (1): QuantizedBNReLU2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Identity()
        (3): QuantizedConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=12.060239791870117, zero_point=77, padding=(1, 1))
        (4): QuantizedBNReLU2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Identity()
      )
      (3): Sequential(
        (0): QuantizedConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), scale=16.808162689208984, zero_point=69, padding=(1, 1))
        (1): QuantizedBNReLU2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Identity()
        (3): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=27.680782318115234, zero_point=80, padding=(1, 1))
        (4): QuantizedBNReLU2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Identity()
      )
    )
    (conv): Sequential(
      (0): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=39.90061950683594, zero_point=66, padding=(1, 1))
      (1): QuantizedBNReLU2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Identity()
      (3): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=102.32366180419922, zero_point=65, padding=(1, 1))
      (4): QuantizedBNReLU2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): Identity()
    )
    (up_sample_layers): ModuleList(
      (0): Sequential(
        (0): QuantizedConv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), scale=1064.0137939453125, zero_point=71, padding=(1, 1))
        (1): QuantizedBNReLU2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Identity()
        (3): QuantizedConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=1038.538330078125, zero_point=73, padding=(1, 1))
        (4): QuantizedBNReLU2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Identity()
      )
      (1): Sequential(
        (0): QuantizedConv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), scale=3193.4365234375, zero_point=99, padding=(1, 1))
        (1): QuantizedBNReLU2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Identity()
        (3): QuantizedConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), scale=1721.619873046875, zero_point=87, padding=(1, 1))
        (4): QuantizedBNReLU2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Identity()
      )
      (2): Sequential(
        (0): QuantizedConv2d(32, 8, kernel_size=(3, 3), stride=(1, 1), scale=2268.27001953125, zero_point=71, padding=(1, 1))
        (1): QuantizedBNReLU2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Identity()
        (3): QuantizedConv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), scale=856.855712890625, zero_point=71, padding=(1, 1))
        (4): QuantizedBNReLU2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Identity()
      )
      (3): Sequential(
        (0): QuantizedConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), scale=493.1239318847656, zero_point=105, padding=(1, 1))
        (1): QuantizedBNReLU2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Identity()
        (3): QuantizedConv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), scale=84.60382080078125, zero_point=26, padding=(1, 1))
        (4): QuantizedBNReLU2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Identity()
      )
    )
    (conv2): Sequential(
      (0): QuantizedConv2d(8, 4, kernel_size=(1, 1), stride=(1, 1), scale=15.952274322509766, zero_point=86)
      (1): QuantizedConv2d(4, 1, kernel_size=(1, 1), stride=(1, 1), scale=9.816205978393555, zero_point=58)
    )
  )
  (last_conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
)

It looks like every layer except the last was quantized properly, as I expected.

But when I run inference with this model as follows, it raises a runtime error in the quantized part of the model.

quantized_model.eval()
quantized_output = quantized_model(norm_input[0:1])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-127-34dec8af2a31> in <module>
      1 quantized_model.eval()
----> 2 quantized_output = quantized_model(norm_input[0:1])

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-118-db9442ff2b9c> in forward(self, x)
     14         # point to quantized in the quantized model
     15         x = self.quant(x)
---> 16         x = self.until_last(x)
     17         x = self.dequant(x)
     18         x = self.last_conv(x)

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/mnt/hdd2/jinwoo/airs_project/sample_forward/sample_forward/utils/unet.py in forward(self, input)
     38         # Apply down-sampling layers
     39         for layer in self.down_sample_layers:
---> 40             output = layer(output)
     41             stack.append(output)
     42             output = F.max_pool2d(output, kernel_size=2)

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/quantized/modules/conv.py in forward(self, input)
    329         if len(input.shape) != 4:
    330             raise ValueError("Input shape must be `(N, C, H, W)`!")
--> 331         return ops.quantized.conv2d(
    332             input, self._packed_params, self.scale, self.zero_point)
    333 

RuntimeError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, Tracer, Autocast, Batched, VmapMode].

QuantizedCPU: registered at /pytorch/aten/src/ATen/native/quantized/cpu/qconv.cpp:858 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:43 [backend fallback]
AutogradXLA: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:47 [backend fallback]
Tracer: fallthrough registered at /pytorch/torch/csrc/jit/frontend/tracer.cpp:967 [backend fallback]
Autocast: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:254 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/BatchingRegistrations.cpp:511 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

I suspected that the non-quantized module could be the one raising this error, so I ran inference with only the quantized module.

# Exclude the last, unquantized layer
quantized_model.until_last.eval()
quantized_model.until_last(norm_input[0:1])

But it raises the same error.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-129-2ede05a56b57> in <module>
      1 quantized_model.until_last.eval()
----> 2 quantized_model.until_last(norm_input[0:1])

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/mnt/hdd2/jinwoo/airs_project/sample_forward/sample_forward/utils/unet.py in forward(self, input)
     38         # Apply down-sampling layers
     39         for layer in self.down_sample_layers:
---> 40             output = layer(output)
     41             stack.append(output)
     42             output = F.max_pool2d(output, kernel_size=2)

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/.conda/envs/airs_project/lib/python3.9/site-packages/torch/nn/quantized/modules/conv.py in forward(self, input)
    329         if len(input.shape) != 4:
    330             raise ValueError("Input shape must be `(N, C, H, W)`!")
--> 331         return ops.quantized.conv2d(
    332             input, self._packed_params, self.scale, self.zero_point)
    333 

RuntimeError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, Tracer, Autocast, Batched, VmapMode].

QuantizedCPU: registered at /pytorch/aten/src/ATen/native/quantized/cpu/qconv.cpp:858 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:43 [backend fallback]
AutogradXLA: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:47 [backend fallback]
Tracer: fallthrough registered at /pytorch/torch/csrc/jit/frontend/tracer.cpp:967 [backend fallback]
Autocast: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:254 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/BatchingRegistrations.cpp:511 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

As far as I can tell, all layers in this module are quantized properly, too.

I searched Google and the PyTorch docs thoroughly, but I'm not sure what to debug next.

Any suggestions would be really welcome.

Thanks.

Please see quant docs: add common errors section by vkuzo · Pull Request #49902 · pytorch/pytorch · GitHub, which landed recently and adds documentation about this error. It looks like one of your layers expects an int8 tensor but is being passed an fp32 tensor. The fix would be to figure out where exactly in your model this is happening (it's in the stack trace) and then add a QuantStub right before it. Also, the code that prepares and converts the model needs to include the QuantStub objects, as those are part of the quantization flow.
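
As a rough sketch (module names are illustrative, not your actual code): every fp32/int8 boundary in forward needs its own stub, and prepare/convert must be run on the module that contains the stubs, so that the stubs get swapped for real Quantize/DeQuantize ops:

class MixedModel(nn.Module):
    def __init__(self, quantized_part, float_part):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # fp32 -> int8 boundary
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> fp32 boundary
        self.quantized_part = quantized_part
        self.float_part = float_part

    def forward(self, x):
        x = self.quant(x)           # everything after this sees int8
        x = self.quantized_part(x)  # runs quantized kernels
        x = self.dequant(x)         # back to fp32
        return self.float_part(x)   # float tail

model = MixedModel(body, tail)
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model.float_part.qconfig = None  # keep the tail in fp32
# Prepare/convert the wrapper, not just the inner module,
# so the stubs themselves are converted.
torch.quantization.prepare(model, inplace=True)
# ... run calibration data ...
torch.quantization.convert(model, inplace=True)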

Thanks! I figured it out with your help.

Now I'm wondering whether it's possible to avoid quantizing the input. I'm currently working on a super-resolution task, which needs the input image at full precision, but a quantized model needs a quantized input. One idea I had is sketched below; any advice would be appreciated.
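
The rough, untested idea (`first_conv`, `quantized_body`, and `last_conv` are placeholders for pieces of my model): keep the first convolution in fp32 so the raw image is processed at full precision, and only quantize after it:

class PartiallyQuantizedUNet(nn.Module):
    def __init__(self, first_conv, quantized_body, last_conv):
        super().__init__()
        self.first_conv = first_conv  # stays fp32, sees the raw image
        self.quant = torch.quantization.QuantStub()
        self.body = quantized_body
        self.dequant = torch.quantization.DeQuantStub()
        self.last_conv = last_conv    # stays fp32

    def forward(self, x):
        x = self.first_conv(x)  # full precision on the input image
        x = self.quant(x)       # quantization starts here
        x = self.body(x)
        x = self.dequant(x)
        return self.last_conv(x)

(I would also set first_conv.qconfig = None and last_conv.qconfig = None before prepare so those layers are skipped.) Does that sound reasonable?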

Best,

Hi KURI, hope you are doing well.
I am facing almost the same issue as yours, so kindly help me with this.
I have trained the model using the fastai and timm libraries.
Currently, I am doing the following:

# Put the trained model in eval mode and select the quantized backend.
effb3_model = learner_effb3.model.eval()

backend = "qnnpack"
effb3_model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

# Insert observers, then convert to quantized modules.
# (No calibration data is run between prepare and convert here.)
model_static_quantized = torch.quantization.prepare(effb3_model, inplace=False)
model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)
print_size_of_model(model_static_quantized)

But I am facing the following error while calling the model for inference:

RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::thnn_conv2d_forward' is only available for these backends: [CPU, CUDA, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

And this is my quantized_model:

Sequential(
  (0): Sequential(
    (0): Conv2dSame(3, 40, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (1): QuantizedBatchNorm2d(40, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU(inplace=True)
    (3): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): QuantizedConv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=40)
          (bn1): QuantizedBatchNorm2d(40, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(40, 10, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(10, 40, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pw): QuantizedConv2d(40, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn2): QuantizedBatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): Identity()
        )
        (1): DepthwiseSeparableConv(
          (conv_dw): QuantizedConv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=24)
          (bn1): QuantizedBatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(24, 6, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(6, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pw): QuantizedConv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn2): QuantizedBatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): Identity()
        )
      )
      (1): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(144, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): Conv2dSame(144, 144, kernel_size=(3, 3), stride=(2, 2), groups=144, bias=False)
          (bn2): QuantizedBatchNorm2d(144, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(144, 6, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(6, 144, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=192)
          (bn2): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(192, 8, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(8, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): InvertedResidual(
          (conv_pw): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=192)
          (bn2): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(192, 8, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(8, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): Conv2dSame(192, 192, kernel_size=(5, 5), stride=(2, 2), groups=192, bias=False)
          (bn2): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(192, 8, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(8, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(288, 288, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=288)
          (bn2): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(288, 12, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(12, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): InvertedResidual(
          (conv_pw): QuantizedConv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(288, 288, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=288)
          (bn2): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(288, 12, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(12, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (3): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): Conv2dSame(288, 288, kernel_size=(3, 3), stride=(2, 2), groups=288, bias=False)
          (bn2): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(288, 12, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(12, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
          (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): InvertedResidual(
          (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
          (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): InvertedResidual(
          (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
          (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (4): InvertedResidual(
          (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
          (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (4): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=576)
          (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(576, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
          (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): InvertedResidual(
          (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
          (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): InvertedResidual(
          (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
          (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (4): InvertedResidual(
          (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
          (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (5): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): Conv2dSame(816, 816, kernel_size=(5, 5), stride=(2, 2), groups=816, bias=False)
          (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(816, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (4): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (5): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (6): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 384, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(384, 2304, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(2304, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(2304, 2304, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=2304)
          (bn2): QuantizedBatchNorm2d(2304, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(2304, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(96, 2304, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(2304, 384, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (4): QuantizedConv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
    (5): QuantizedBatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (6): SiLU(inplace=True)
  )
  (1): Sequential(
    (0): AdaptiveConcatPool2d(
      (ap): AdaptiveAvgPool2d(output_size=1)
      (mp): AdaptiveMaxPool2d(output_size=1)
    )
    (1): Flatten(full=False)
    (2): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.25, inplace=False)
    (4): QuantizedLinear(in_features=3072, out_features=512, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
    (5): ReLU(inplace=True)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.5, inplace=False)
    (8): QuantizedLinear(in_features=512, out_features=73, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
  )
)

Thanks for any help.

@Muhammad_Ali
Hi Ali,
I am facing the same issue.
Have you found a solution for this?
Thanks.

Hi @bigtree, no, I was not able to solve it.