MobileNetV2 quantization outputs all zeros

I’ve been trying to statically quantize the MobileNetV2 model written by the PyTorch team.
Unfortunately, the quantized model outputs all zeros and I’m not sure where the problem is coming from.
Any help would be appreciated.
Code
The slightly modified MobileNetV2 code from PyTorch is below. Essentially, all I’ve changed is the forward method of the InvertedResidual block, where I use a FloatFunctional for the addition.

from torch import nn
from torch import Tensor
from typing import Callable, Any, Optional, List

__all__ = ['MobileNetV2', 'mobilenet_v2']

model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNActivation(nn.Sequential):
    def __init__(
            self,
            in_planes: int,
            out_planes: int,
            kernel_size: int = 3,
            stride: int = 1,
            groups: int = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            activation_layer: Optional[Callable[..., nn.Module]] = None,
            dilation: int = 1,
    ) -> None:
        padding = (kernel_size - 1) // 2 * dilation
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.ReLU6
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
                      bias=False),
            norm_layer(out_planes),
            activation_layer(inplace=True)
        )
        self.out_channels = out_planes


# necessary for backwards compatibility
ConvBNReLU = ConvBNActivation


class InvertedResidual(nn.Module):
    def __init__(
            self,
            inp: int,
            oup: int,
            stride: int,
            expand_ratio: int,
            norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
        ])
        self.conv = nn.Sequential(*layers)
        self.out_channels = oup
        self._is_cn = stride > 1

        self.floatFunctional = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return self.floatFunctional.add(x, self.conv(x))
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(
            self,
            num_classes: int = 1000,
            width_mult: float = 1.0,
            inverted_residual_setting: Optional[List[List[int]]] = None,
            round_nearest: int = 8,
            block: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        MobileNet V2 main class
        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use
        """
        super(MobileNetV2, self).__init__()

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def mobilenet_v2() -> MobileNetV2:
    model = MobileNetV2()

    return model

The inputs I feed to the model are fine and in the correct format, since I reuse the preprocessing from the original PyTorch guide. The input batch is shaped [numberOfImages, 3, 224, 224].
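For reference, a batch of that shape can be assembled roughly like this (the folder path is just a placeholder and the preprocessing is the standard ImageNet pipeline from the guide):

import os

import torch
from PIL import Image
from torchvision import transforms

# Standard ImageNet preprocessing, as used in the PyTorch quantization guide.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image_dir = 'path/to/calibration/images'  # placeholder
batchedInputs = torch.stack([
    preprocess(Image.open(os.path.join(image_dir, f)).convert('RGB'))
    for f in sorted(os.listdir(image_dir))
])  # -> [numberOfImages, 3, 224, 224]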

The quantization and inference code is:

import mobilenet
import torch


class NewModel(torch.nn.Module):
  def __init__(self, model):
    super(NewModel, self).__init__()
    self.quant = torch.quantization.QuantStub()
    self.model = model
    self.dequant = torch.quantization.DeQuantStub()
    self.softmax = torch.nn.Softmax()
  
  def forward(self, x):
    # return self.softmax(self.dequant(self.model(self.quant(x))))
    return self.model(self.quant(x))

mobileNetModel = mobilenet.mobilenet_v2()
# mobileNetModel.eval()
# mobileNetModel.qconfig = torch.quantization.get_default_qconfig('fbgemm')
newModel = NewModel(mobileNetModel)
newModel.eval()
newModel.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# mQuan = torch.quantization.fuse_modules(model, [])
mQuan = torch.quantization.prepare(newModel)
mQuan(batchedInputs[1:, ...])           # calibration pass over all but the first image
mQuan = torch.quantization.convert(mQuan)
res = mQuan(batchedInputs[0:1, ...])    # inference on the held-out first image
print(res)
for logit in res[0]:
  if logit != 0:
    print("yey")

As is evident from the block above, I don’t fuse any modules, either because there is nothing to fuse or because I don’t know how to fuse the given blocks; a possible fusion pass is sketched below for reference.
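This is just a sketch against the module names in the code above, not something I have run: it fuses each Conv2d with its following BatchNorm2d, and leaves the ReLU6 alone (fusing that as well would mean swapping it for nn.ReLU, as the official static quantization tutorial does). It should be called after model.eval() and before prepare.

import torch
from mobilenet import ConvBNActivation, InvertedResidual

def fuse_mobilenet(model):
    for m in model.modules():
        if isinstance(m, ConvBNActivation):
            # fuse Conv2d ('0') with its BatchNorm2d ('1')
            torch.quantization.fuse_modules(m, ['0', '1'], inplace=True)
        elif isinstance(m, InvertedResidual):
            # the last two entries of m.conv are the pw-linear Conv2d and its BN
            last_conv = len(m.conv) - 2
            torch.quantization.fuse_modules(
                m.conv, [str(last_conv), str(last_conv + 1)], inplace=True)
    return model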
The output that I get is:

/usr/local/lib/python3.7/dist-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  reduce_range will be deprecated in a future release of PyTorch."
/usr/local/lib/python3.7/dist-packages/torch/quantization/observer.py:990: UserWarning: must run observer before calling calculate_qparams.                                    Returning default scale and zero point 
  Returning default scale and zero point "
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         ...
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       size=(1, 1000), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=1.1920928955078125e-07,
       zero_point=0)

What have I done wrong?

I don’t think it’s important, but for completeness, I’ve run the code both on Google Colab (CPU) and on my own system. My system has the following relevant specs:
python 3.8.6
tensorboard 2.4.1
tensorboard-plugin-wit 1.8.0
thop 0.0.31.post2005241907
torch 1.7.0+cpu
torchaudio 0.7.0
torchvision 0.8.1+cpu

/usr/local/lib/python3.7/dist-packages/torch/quantization/observer.py:990: UserWarning: must run observer before calling calculate_qparams. Returning default scale and zero point
Returning default scale and zero point "

This error message means that there are observers in the network which have not been calibrated. If observers are not calibrated, then they will use scale=1.0 and zero_point=0, which is probably not going to be useful. Could you verify that your calibration is working correctly, and all observers have scale and zero_point collected based on your calibration data?

It looks like you calibrate here, with the single mQuan(batchedInputs[1:, ...]) call. This may be an area to debug: you could try calibrating with more data, then look at the observers and verify that all of them have collected statistics. If you print out the prepared model, the observer statistics will be included.
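For example, something along these lines (just a sketch; it assumes the default observers, which expose min_val/max_val buffers), run on the prepared model before convert, will show which observers actually recorded a range:

# List every observer in the prepared (not yet converted) model and the range
# it has recorded; observers that never saw data keep their initial values,
# which is what triggers the warning above.
for name, module in mQuan.named_modules():
    if hasattr(module, 'min_val') and hasattr(module, 'max_val'):
        print(name, module.min_val, module.max_val)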

I actually thought at first that this might be the problem, but when I searched for this specific error I found this issue. Going by that issue, the warning is probably just because the InvertedResidual block follows one of two paths through forward, so the FloatFunctional observer in the blocks without a residual connection never sees any data.
When I print the model, the layers do seem to be properly quantized (I’m not sure about this, though, since this is essentially my first quantized model and I don’t know what to expect). The print output is:

NewModel(
  (quant): Quantize(scale=tensor([0.0374]), zero_point=tensor([57]), dtype=torch.quint8)
  (model): MobileNetV2(
    (features): Sequential(
      (0): ConvBNActivation(
        (0): QuantizedConv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), scale=0.03452397137880325, zero_point=58, padding=(1, 1), bias=False)
        (1): QuantizedBatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): QuantizedReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.008563793264329433, zero_point=52, padding=(1, 1), groups=32, bias=False)
            (1): QuantizedBatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): QuantizedConv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.008242965675890446, zero_point=79, bias=False)
          (2): QuantizedBatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
      )
      (2): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), scale=0.004633823875337839, zero_point=58, bias=False)
            (1): QuantizedBatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), scale=0.0007947214762680233, zero_point=75, padding=(1, 1), groups=96, bias=False)
            (1): QuantizedBatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), scale=0.0007379651651717722, zero_point=64, bias=False)
          (3): QuantizedBatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
      )
      (3): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), scale=0.0004368616209831089, zero_point=63, bias=False)
            (1): QuantizedBatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), scale=6.097953882999718e-05, zero_point=67, padding=(1, 1), groups=144, bias=False)
            (1): QuantizedBatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), scale=8.161034202203155e-05, zero_point=62, bias=False)
          (3): QuantizedBatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=0.0007560973172076046, zero_point=64
          (activation_post_process): Identity()
        )
      )
      (4): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), scale=0.00046657773782499135, zero_point=65, bias=False)
            (1): QuantizedBatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), scale=6.993811985012144e-05, zero_point=67, padding=(1, 1), groups=144, bias=False)
            (1): QuantizedBatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), scale=6.939686136320233e-05, zero_point=69, bias=False)
          (3): QuantizedBatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
      )
      (5): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=5.0384467613184825e-05, zero_point=58, bias=False)
            (1): QuantizedBatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=4.988682576367864e-06, zero_point=65, padding=(1, 1), groups=192, bias=False)
            (1): QuantizedBatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=4.922353127767565e-06, zero_point=70, bias=False)
          (3): QuantizedBatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=7.048285624478012e-05, zero_point=68
          (activation_post_process): Identity()
        )
      )
      (6): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=3.9113976527005434e-05, zero_point=64, bias=False)
            (1): QuantizedBatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=4.21840013586916e-06, zero_point=59, padding=(1, 1), groups=192, bias=False)
            (1): QuantizedBatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=6.124475930846529e-06, zero_point=69, bias=False)
          (3): QuantizedBatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=6.926279456820339e-05, zero_point=68
          (activation_post_process): Identity()
        )
      )
      (7): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=4.882919165538624e-05, zero_point=58, bias=False)
            (1): QuantizedBatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), scale=6.18789727013791e-06, zero_point=65, padding=(1, 1), groups=192, bias=False)
            (1): QuantizedBatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), scale=4.26427686761599e-06, zero_point=69, bias=False)
          (3): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
      )
      (8): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), scale=2.777898998829187e-06, zero_point=66, bias=False)
            (1): QuantizedBatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), scale=2.495579565220396e-07, zero_point=74, padding=(1, 1), groups=384, bias=False)
            (1): QuantizedBatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), scale=2.758661423740705e-07, zero_point=70, bias=False)
          (3): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=4.288005584385246e-06, zero_point=69
          (activation_post_process): Identity()
        )
      )
      (9): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), scale=3.1199779186863452e-06, zero_point=65, bias=False)
            (1): QuantizedBatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), scale=2.582654303751042e-07, zero_point=80, padding=(1, 1), groups=384, bias=False)
            (1): QuantizedBatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), scale=2.4579065893703955e-07, zero_point=68, bias=False)
          (3): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=4.267273652658332e-06, zero_point=69
          (activation_post_process): Identity()
        )
      )
      (10): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), scale=2.7751552806876134e-06, zero_point=67, bias=False)
            (1): QuantizedBatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), scale=2.0391455279877846e-07, zero_point=71, padding=(1, 1), groups=384, bias=False)
            (1): QuantizedBatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), scale=2.1148737516796245e-07, zero_point=59, bias=False)
          (3): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=4.199701834295411e-06, zero_point=70
          (activation_post_process): Identity()
        )
      )
      (11): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), scale=2.6447246455063578e-06, zero_point=64, bias=False)
            (1): QuantizedBatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), scale=2.6378887696409947e-07, zero_point=79, padding=(1, 1), groups=384, bias=False)
            (1): QuantizedBatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), scale=2.007121366887077e-07, zero_point=70, bias=False)
          (3): QuantizedBatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
      )
      (12): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=57, bias=False)
            (1): QuantizedBatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=5, padding=(1, 1), groups=576, bias=False)
            (1): QuantizedBatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=6, bias=False)
          (3): QuantizedBatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.9502708425989113e-07, zero_point=72
          (activation_post_process): Identity()
        )
      )
      (13): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.2644522939808667e-07, zero_point=64, bias=False)
            (1): QuantizedBatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=6, padding=(1, 1), groups=576, bias=False)
            (1): QuantizedBatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=6, bias=False)
          (3): QuantizedBatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.9090400371624128e-07, zero_point=72
          (activation_post_process): Identity()
        )
      )
      (14): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=1.250595147439526e-07, zero_point=60, bias=False)
            (1): QuantizedBatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), scale=1.1920928955078125e-07, zero_point=6, padding=(1, 1), groups=576, bias=False)
            (1): QuantizedBatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=5, bias=False)
          (3): QuantizedBatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
      )
      (15): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=3, bias=False)
            (1): QuantizedBatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=0, padding=(1, 1), groups=960, bias=False)
            (1): QuantizedBatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=0, bias=False)
          (3): QuantizedBatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.1920928955078125e-07, zero_point=5
          (activation_post_process): Identity()
        )
      )
      (16): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=3, bias=False)
            (1): QuantizedBatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=0, padding=(1, 1), groups=960, bias=False)
            (1): QuantizedBatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=0, bias=False)
          (3): QuantizedBatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.1920928955078125e-07, zero_point=5
          (activation_post_process): Identity()
        )
      )
      (17): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNActivation(
            (0): QuantizedConv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=3, bias=False)
            (1): QuantizedBatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (1): ConvBNActivation(
            (0): QuantizedConv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=0, padding=(1, 1), groups=960, bias=False)
            (1): QuantizedBatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): QuantizedReLU6(inplace=True)
          )
          (2): QuantizedConv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=0, bias=False)
          (3): QuantizedBatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (floatFunctional): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
      )
      (18): ConvBNActivation(
        (0): QuantizedConv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), scale=1.1920928955078125e-07, zero_point=0, bias=False)
        (1): QuantizedBatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): QuantizedReLU6(inplace=True)
      )
    )
    (classifier): Sequential(
      (0): Dropout(p=0.2, inplace=False)
      (1): QuantizedLinear(in_features=1280, out_features=1000, scale=1.1920928955078125e-07, zero_point=0, qscheme=torch.per_channel_affine)
    )
  )
  (dequant): DeQuantize()
  (softmax): Softmax(dim=None)
)

It could be that your model is sensitive to quantization. There is a prototype tool to help narrow this down to a particular layer: PyTorch Numeric Suite Tutorial — PyTorch Tutorials 1.7.1 documentation. One thing to try is to run an example input through this tool and see if there is a particular problematic layer where things diverge.
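As a rough sketch of the weight-comparison part of that tutorial (here float_model and q_model stand for your calibrated float model and the converted model; compute_error is the SNR helper defined in the tutorial):

import torch
import torch.quantization._numeric_suite as ns

def compute_error(x, y):
    # Signal-to-quantization-noise ratio in dB (helper from the tutorial).
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)

wt_compare_dict = ns.compare_weights(float_model.state_dict(), q_model.state_dict())
for key in wt_compare_dict:
    print(key, compute_error(wt_compare_dict[key]['float'],
                             wt_compare_dict[key]['quantized'].dequantize()))

A low SNR for a particular layer would point to where the quantization error enters.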

Thanks!
I tried the package and it helped a bit. I did have trouble getting the same kinds of outputs as the guide: Compare the weights of float and quantized models worked as expected, but the other two parts of the tutorial didn’t, and the output dictionaries didn’t contain the quantized activations.
What I found is that the output layer of the original float model produces values on the order of 10^(-9), while the scale of the quantized model’s output tensor is about 10^(-7), so if I understand correctly, values around 10^(-9) end up rounded to zero.
How can I change this? I checked all the inputs I fed to the model (for the observers), and all of their outputs maxed out around 10^(-9), so it seems odd that the quantized model chose ~10^(-7) as the scale.
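A quick check of that rounding effect (the scale below is the one reported for the final QuantizedLinear above; as far as I can tell, 1.1920928955078125e-07 is just torch.finfo(torch.float32).eps, i.e. the smallest scale the observers will ever pick):

import torch

scale = 1.1920928955078125e-07        # scale of the final QuantizedLinear (= torch.finfo(torch.float32).eps)
x = torch.tensor([1e-9, 5e-8, 2e-7])  # magnitudes similar to the float model's outputs
q = torch.quantize_per_tensor(x, scale=scale, zero_point=0, dtype=torch.quint8)
print(q.int_repr())  # tensor([0, 0, 2], dtype=torch.uint8): anything below ~scale/2 collapses to 0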

I just wanted to let everybody know what the problem was, hoping that nobody else makes this foolish mistake :smile:
The issue was that I never trained my network. Now that I load a pre-trained model and then quantize it, the output is no longer all zeros (I haven’t checked the accuracy difference yet).
I hadn’t used a pre-trained model initially because, since I had changed the underlying network a bit, I didn’t bother adapting the state_dict loading and was just initializing the model randomly. I thought this wouldn’t be a problem because PyTorch would still quantize the activations well enough, but it seems that with such random, unstructured outputs the activation observers don’t work very well (the weights themselves were quantized correctly, as I had checked that the per-layer SNR was good).
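For completeness, loading the pretrained weights into the modified model can be done roughly like this (the added FloatFunctional holds no parameters of its own, so the original state_dict keys should still line up; strict=False is only a safety net):

import torch
from torch.hub import load_state_dict_from_url

import mobilenet

model = mobilenet.mobilenet_v2()
state_dict = load_state_dict_from_url(mobilenet.model_urls['mobilenet_v2'], progress=True)
# The added FloatFunctional has no parameters, so the pretrained keys match;
# strict=False just tolerates any incidental naming differences.
model.load_state_dict(state_dict, strict=False)
model.eval()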