Quantized Conv2d bug

In my test, if the input's (dtype quint8) zero point is large, for example 128, torch.nn.quantized.Conv2d gives a wrong result on Ubuntu 18.04 or Windows 10. Some output feature map points match the correct result and some do not, and the difference is much larger than 1 or 2, about 10 or 20.

If I set the input's zero point smaller, e.g. 75, the quantized conv2d becomes correct. However, different images have different ranges: for example, one image after quantization has the range [-54, 81] and the maximum working zero point is 75, while another has [-66, 84] and the maximum working zero point is 69. There is no regular pattern.

However, on CentOS 7 the quantized Conv2d gives a correct result. The scripts and the Python, PyTorch, and torchvision versions are exactly the same.

So I have no idea why. I would really appreciate it if anyone could help me.
My test CNN is ResNet50 and the dataset is ImageNet ILSVRC2012.

Could you provide a test case so that we can reproduce the problem?
cc @dskhudia @Zafar

First, I make use of the pre-trained quantized ResNet50 model from torchvision.models.quantization.resnet:

import torch
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from torch._jit_internal import Optional

__all__ = ['QuantizableResNet', 'resnet50']

quant_model_urls = {
    'resnet50_fbgemm':
        'https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth',
}


def _replace_relu(module):
    reassign = {}
    for name, mod in module.named_children():
        _replace_relu(mod)
        # Checking for explicit type instead of instance
        # as we only want to replace modules of the exact type
        # not inherited classes
        if type(mod) == nn.ReLU or type(mod) == nn.ReLU6:
            reassign[name] = nn.ReLU(inplace=False)

    for key, value in reassign.items():
        module._modules[key] = value


def quantize_model(model, backend):
    _dummy_input_data = torch.rand(1, 3, 299, 299)
    if backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported ")
    torch.backends.quantized.engine = backend
    model.eval()
    # Make sure that weight qconfig matches that of the serialized models
    if backend == 'fbgemm':
        model.qconfig = torch.quantization.QConfig(
            activation=torch.quantization.default_observer,
            weight=torch.quantization.default_per_channel_weight_observer)
    elif backend == 'qnnpack':
        model.qconfig = torch.quantization.QConfig(
            activation=torch.quantization.default_observer,
            weight=torch.quantization.default_weight_observer)

    model.fuse_model()
    torch.quantization.prepare(model, inplace=True)
    model(_dummy_input_data)
    torch.quantization.convert(model, inplace=True)

    return


class QuantizableBottleneck(Bottleneck):
    def __init__(self, *args, **kwargs):
        super(QuantizableBottleneck, self).__init__(*args, **kwargs)
        self.skip_add_relu = nn.quantized.FloatFunctional()
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self):
        fuse_modules(self, [['conv1', 'bn1', 'relu1'],
                            ['conv2', 'bn2', 'relu2'],
                            ['conv3', 'bn3']], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)


class QuantizableResNet(ResNet):

    def __init__(self, *args, **kwargs):
        super(QuantizableResNet, self).__init__(*args, **kwargs)

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        # Ensure scriptability
        # super(QuantizableResNet,self).forward(x)
        # is not scriptable
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        x = self.dequant(x)
        return x

    def fuse_model(self):
        r"""Fuse conv/bn/relu modules in resnet models
        Fuse conv+bn+relu / conv+relu / conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """

        fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True)
        for m in self.modules():
            if type(m) == QuantizableBottleneck:
                m.fuse_model()


def _resnet(arch, block, layers, pretrained, progress, quantize, **kwargs):
    model = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(model)
    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = 'fbgemm'
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            model_url = quant_model_urls[arch + '_' + backend]
        else:
            model_url = model_urls[arch]

        state_dict = load_state_dict_from_url(model_url,
                                              progress=progress)

        model.load_state_dict(state_dict)
    return model


def resnet50(pretrained=False, progress=True, quantize=False, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress,
                   quantize, **kwargs)

Then I do inference on an image from ImageNet:

from PIL import Image
import torch.backends.quantized
import torchvision.transforms as transforms
from qresnet50_original import resnet50
import numpy as np

model = resnet50(pretrained=True, quantize=True)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

img = Image.open("n01440764-0-tench.JPEG")

img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)
# quantized input, and to test zero point effect, set zp 128
pre_quant = model.quant
pre_quant.zero_point = torch.tensor([128])
q_input = pre_quant(batch_t)

# Build a float reference conv from the dequantized weights and bias of the
# quantized conv1 (which is fused with bn1 and relu in the quantized model)
conv1_fp = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=True)
conv1_fp.weight = torch.nn.Parameter(model.conv1.weight().dequantize())
conv1_fp.bias = torch.nn.Parameter(model.conv1.bias())
relu = torch.nn.ReLU(inplace=True)

# Compare the float reference and the quantized conv, both expressed in units
# of the quantized output scale
scale = model.conv1.scale
correct_output = np.round(relu(conv1_fp(q_input.dequantize()) / scale).detach().numpy())

output = np.round((model.conv1(q_input).detach().dequantize() / scale).numpy())
diff = np.abs(correct_output - output)
print(np.sum(diff))

If I change pre_quant.zero_point to 128 (in fact, anything larger than 75 causes a difference), the difference between the correct output and the quantized output becomes large.
The test image is n01440764-0-tench.JPEG from ImageNet.

My PC info: OS: Ubuntu 18.04, CPU: i7-9700K, wrong output
Laptop info: OS: Windows 10, CPU: i7-8565U, wrong output
My GPU server: OS: CentOS 7.7, CPU: Intel® Xeon® Gold 6230, correct output

PyTorch versions tested: 1.3.1, 1.4, 1.5. Python versions tested: 3.6, 3.7.

Sorry for the late reply. I look forward to your help!

The scale and zero point for input are calculated after a calibration step across a representative set of images. By default, for PyTorch quantization, these are chosen such that quantized values are in the range [0, 127] for activations and zero_point is somewhere in the same range (e.g., 57 for the model here). We restrict quantized activation values to [0, 127] instead of [0, 255] for uint8 to avoid the saturation explained here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/qconv.cpp#L633-L661
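
For illustration, here is a minimal sketch (assuming the default activation observer is a MinMaxObserver with reduce_range=True) of how reduce_range keeps the activation quantization parameters within [0, 127]:

import torch
from torch.quantization import MinMaxObserver

x = torch.randn(1000) * 3  # stand-in for calibration activations

# reduce_range=True mirrors what the default fbgemm activation observer uses
obs_reduced = MinMaxObserver(dtype=torch.quint8, reduce_range=True)
obs_full = MinMaxObserver(dtype=torch.quint8, reduce_range=False)
obs_reduced(x)
obs_full(x)

print(obs_reduced.calculate_qparams())  # zero_point lands in [0, 127], scale roughly 2x larger
print(obs_full.calculate_qparams())     # zero_point can be anywhere in [0, 255]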

When you force the zero_point to 128, the quantized activations fall in the range [0, 255] and it is possible that the saturation happens. BTW, any reason for forcing zero_point to such a large value? Let the calibration step pick the scale and zero_points for the model you are quantizing.

It is strange that you see different behavior on the cpus you mentioned. I would expect the same behavior on all the CPUs you mentioned.

Well, I am trying to use a symmetric quantization scheme, which forces the zero point to 128, because the asymmetric quantization mode causes some other problems for our project. In that case, as I understand it, the range is [-128, 127], and after ReLU the range becomes [0, 127]. That assumption seems to be wrong. In fact, I would prefer to use qint8 as the feature map's data type, but qint8 is not supported.
So can I say that in fbgemm, the quantized conv computes the unsigned-integer summation first and then subtracts the zero-point terms from that summation? And that summation may cause saturation.
Also, the different behaviors across CPUs confuse me…
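
To show what I mean about the symmetric scheme, here is a minimal sketch (assuming torch.quantization.MinMaxObserver) of how a symmetric quint8 observer pins the zero point at 128:

import torch
from torch.quantization import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_symmetric)
obs(torch.randn(1000))                  # any calibration data
scale, zero_point = obs.calculate_qparams()
print(zero_point)                       # 128 for symmetric quantization to quint8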

For activations (uint8), the range is [0, 127] and after ReLU it stays the same. Which summation are you talking about? Accumulations during convolution? Accumulations (activation times weight accumulations) are done in int32.

Okay, so the saturation happens during the multiplication, not during the accumulations of the convolution. But I assumed that before the convolution, the uint8 values in [0, 255] have the zero point 128 subtracted, which would give the range [-128, 127]. That seems to be wrong.
Besides, when using nn.quantized.Conv2d rather than nn.intrinsic.quantized.Conv2dReLU, the output feature map range should be [-128, 127] (symmetric quantization mode).
Actually, I am confused about why uint8 is used rather than int8.

Is it true that the multiplication is done before subtracting the zero point?

Yes. It’s doing multiplication before subtracting zero_point.

Okay, so the saturation happens during the multiplication, not during the accumulations of the convolution.

The saturation I mentioned earlier happens during the multiplication. However, for a large number of channels, even the accumulations into an int32 might overflow. This should be very rare and should only happen for a very large number of channels. In such a case there is no saturation, but the accumulator overflows and wraps around.
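
As a back-of-the-envelope illustration (plain Python, not the actual fbgemm kernel): vpmaddubsw sums two adjacent uint8*int8 products into a signed 16-bit lane, so the worst cases look like this:

INT16_MAX = 32767

def pair_product(a0, a1, w0, w1):
    # two adjacent uint8*int8 products accumulated into one int16 lane
    return a0 * w0 + a1 * w1

# full uint8 activation range (e.g. zero_point forced to 128): exceeds int16, so it saturates
print(pair_product(255, 255, 127, 127), ">", INT16_MAX)    # 64770 > 32767
# activations restricted to [0, 127] (reduce_range): worst case still fits in int16
print(pair_product(127, 127, 127, 127), "<=", INT16_MAX)   # 32258 <= 32767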

Actually, I am confused about why uint8 is used rather than int8.

We are using uint8 for activations because the x64 vector instruction that multiplies two 8-bit integers (vpmaddubsw) accepts uint8 for one input and int8 for the other. We chose uint8 for activations and int8 for weights. int8 activations could also be supported with a preprocessing step, i.e., by converting int8 to uint8 and adjusting the activation zero_point (add 128 to all int8 elements and set new_zero_point = old_zero_point + 128 in the preprocessing step).
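
That preprocessing step is just a constant shift. A tiny numeric check (with made-up values) that it leaves the represented real values unchanged, since real_value = scale * (q - zero_point):

import numpy as np

scale = 0.05
old_zp = -3                                   # hypothetical int8 zero_point
q_int8 = np.array([-128, -3, 0, 127], dtype=np.int32)

q_uint8 = q_int8 + 128                        # shift the int8 codes into [0, 255]
new_zp = old_zp + 128                         # new_zero_point = old_zero_point + 128

print(scale * (q_int8 - old_zp))              # original real values
print(scale * (q_uint8 - new_zp))             # identical real values after the shift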

Thank you very much for your explanation!

Has there been any update to this? There's a "reduce_range" parameter which seems to fix the saturation issue discussed here. However, the reduce_range functionality differs from standard int8 quantized convolutions, which makes predicting model performance and interfacing with other quantization libraries difficult. Thanks.

hi @amrmartini , just curious, what specifically would be helpful?

We do have plans to document the current behavior of reduce_range better and surface a warning if a user is using fbgemm with reduce_range set to False, to warn about potential overflow. We do not have plans at the moment to remove this restriction.
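
For anyone following along, here is a minimal sketch (assuming the eager-mode QConfig and MinMaxObserver APIs) of where reduce_range is set on the activation observer:

import torch
import torch.nn as nn
from torch.quantization import QConfig, MinMaxObserver, default_per_channel_weight_observer

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

# reduce_range=True restricts activations to [0, 127] to avoid the fbgemm saturation
# discussed above; reduce_range=False keeps the full [0, 255] range
model.qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8, reduce_range=True),
    weight=default_per_channel_weight_observer,
)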

Hi all, hope you are fine.
I am stuck with the post-training quantization process:
I have trained the model using the fastai and timm libraries.
Currently, I am doing the following:

effb3_model = learner_effb3.model.eval()

backend = "qnnpack"

# Attach the default qconfig for the backend and select the quantized engine
effb3_model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
# Insert observers, then convert modules to their quantized counterparts
# (no calibration data is passed between prepare and convert here)
model_static_quantized = torch.quantization.prepare(effb3_model, inplace=False)
model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)
print_size_of_model(model_static_quantized)

But I am facing the following error while calling the model for inference:

RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::thnn_conv2d_forward' is only available for these backends: [CPU, CUDA, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

This is my model:

Sequential(
  (0): Sequential(
    (0): Conv2dSame(3, 40, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (1): QuantizedBatchNorm2d(40, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU(inplace=True)
    (3): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): QuantizedConv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=40)
          (bn1): QuantizedBatchNorm2d(40, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(40, 10, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(10, 40, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pw): QuantizedConv2d(40, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn2): QuantizedBatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): Identity()
        )
        (1): DepthwiseSeparableConv(
          (conv_dw): QuantizedConv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=24)
          (bn1): QuantizedBatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(24, 6, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(6, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pw): QuantizedConv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn2): QuantizedBatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): Identity()
        )
      )
      (1): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(144, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): Conv2dSame(144, 144, kernel_size=(3, 3), stride=(2, 2), groups=144, bias=False)
          (bn2): QuantizedBatchNorm2d(144, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(144, 6, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(6, 144, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=192)
          (bn2): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(192, 8, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(8, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): InvertedResidual(
          (conv_pw): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=192)
          (bn2): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(192, 8, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(8, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): Conv2dSame(192, 192, kernel_size=(5, 5), stride=(2, 2), groups=192, bias=False)
          (bn2): QuantizedBatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(192, 8, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(8, 192, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(288, 288, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=288)
          (bn2): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(288, 12, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(12, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): InvertedResidual(
          (conv_pw): QuantizedConv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(288, 288, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=288)
          (bn2): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(288, 12, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(12, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (3): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): Conv2dSame(288, 288, kernel_size=(3, 3), stride=(2, 2), groups=288, bias=False)
          (bn2): QuantizedBatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(288, 12, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(12, 288, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
          (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): InvertedResidual(
          (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
          (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): InvertedResidual(
          (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
          (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (4): InvertedResidual(
          (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=576)
          (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (4): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=576)
          (bn2): QuantizedBatchNorm2d(576, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(576, 24, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(24, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(576, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
          (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): InvertedResidual(
          (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
          (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): InvertedResidual(
          (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
          (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (4): InvertedResidual(
          (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=816)
          (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(136, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (5): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): Conv2dSame(816, 816, kernel_size=(5, 5), stride=(2, 2), groups=816, bias=False)
          (bn2): QuantizedBatchNorm2d(816, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(816, 34, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(34, 816, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(816, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (4): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (5): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0, padding=(2, 2), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(232, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (6): Sequential(
        (0): InvertedResidual(
          (conv_pw): QuantizedConv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(1392, 1392, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=1392)
          (bn2): QuantizedBatchNorm2d(1392, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(1392, 384, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): InvertedResidual(
          (conv_pw): QuantizedConv2d(384, 2304, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn1): QuantizedBatchNorm2d(2304, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (conv_dw): QuantizedConv2d(2304, 2304, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), groups=2304)
          (bn2): QuantizedBatchNorm2d(2304, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act2): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): QuantizedConv2d(2304, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
            (act1): SiLU(inplace=True)
            (conv_expand): QuantizedConv2d(96, 2304, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          )
          (conv_pwl): QuantizedConv2d(2304, 384, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
          (bn3): QuantizedBatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (4): QuantizedConv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
    (5): QuantizedBatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (6): SiLU(inplace=True)
  )
  (1): Sequential(
    (0): AdaptiveConcatPool2d(
      (ap): AdaptiveAvgPool2d(output_size=1)
      (mp): AdaptiveMaxPool2d(output_size=1)
    )
    (1): Flatten(full=False)
    (2): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.25, inplace=False)
    (4): QuantizedLinear(in_features=3072, out_features=512, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
    (5): ReLU(inplace=True)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.5, inplace=False)
    (8): QuantizedLinear(in_features=512, out_features=73, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
  )
)

Thanks for any help.
