How to do static quantization? I have some questions after reading the tutorial

I quantized the model following the tutorial, but the model's error is fairly large. So I tried quantizing the weights myself, dequantizing them, and then computing the convolution, and found that this error is smaller than the error from the tutorial's method. What is the reason? Did I miss a step when quantizing the model? Below is a demo.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub

torch.backends.quantized.engine = 'qnnpack'

class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.conv = nn.Conv2d(132, 121, 3, 1, bias=False)
        self.quan = QuantStub()
        self.dequan = DeQuantStub()
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out')

    def forward(self, x):
        out = self.quan(x)
        out = self.conv(out)
        out = self.dequan(out)
        return out

def test():
    model = Test()
    weight = model.conv.weight
    conv = nn.Conv2d(132, 121, 3, 1, bias=False)
    conv.weight = weight
    # manually quantize the weights per-tensor, then dequantize them
    w_max, w_min = weight.max(), weight.min()
    scale = (w_max - w_min) / 256
    temp = torch.quantize_per_tensor(weight, scale=scale.item(), zero_point=0, dtype=torch.qint8)
    temp_de = temp.dequantize()
    example = torch.rand(1, 132, 64, 64)
    print(F.mse_loss(temp_de, weight))
    # tensor(1.9924e-07, grad_fn=<MeanBackward0>)
    print(F.mse_loss(F.conv2d(example, weight=temp_de, bias=None, stride=1), conv(example)))
    # tensor(8.7799e-05, grad_fn=<MeanBackward0>)
    torch.backends.quantized.engine = 'qnnpack'
    model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
    torch.quantization.prepare(model, inplace=True)
    torch.quantization.convert(model, inplace=True)
    print(F.mse_loss(conv(example), model(example)))
    # tensor(0.4824, grad_fn=<MseLossBackward>)


if __name__ == "__main__":
    test()
    print(torch.__version__)
    # 1.5.0

In the first comparison, you are not quantizing the activations (i.e., the example tensor). When you call torch.quantization.prepare and convert, the activations are quantized too. In this case, since you are skipping calibration (i.e., you are calling convert immediately after prepare), the activations are quantized with a default scale of 1.0, leading to a large quantization error. Please repeat the experiment with calibration and you should see a lower error.
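
For reference, a minimal sketch of what the calibration step could look like. Here random tensors are used as stand-in calibration data and the 10 iterations are arbitrary; in practice you would feed batches drawn from your real input distribution:

model = Test()
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.quantization.prepare(model, inplace=True)
# calibration: run representative inputs through the prepared model so the
# observers can record activation ranges before conversion
with torch.no_grad():
    for _ in range(10):
        model(torch.rand(1, 132, 64, 64))
torch.quantization.convert(model, inplace=True)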

Thanks. I know that I need to run the model before converting it.