The result of quantized conv2d is different from the result I calculate

Dear PyTorch forum, I find that the result of quantized conv2d is different from what I calculate by hand.

First, I import the necessary packages:

import torch
import torch.nn as nn

Then I build a simple convolution model:

class CustomModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, stride=1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.dequant(x)
        return x

net = CustomModel()
net.eval()

Configuration for quantizing the model:

my_qconfig = torch.quantization.qconfig.QConfig(
        activation=torch.quantization.observer.HistogramObserver.with_args(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True),
        weight=torch.quantization.observer.PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False )
)

net.qconfig = my_qconfig
torch.backends.quantized.engine = "fbgemm"
torch.quantization.prepare(net, inplace=True)

Calibrate the model with randomly generated data:

calibrate_data = torch.randint(low=0, high=255, size=(1, 4, 16), dtype=torch.uint8).unsqueeze(0)
calibrate_data = calibrate_data / 255
_ = net(calibrate_data)

Quantize model:

torch.quantization.convert(net, inplace=True)

Output of quantized model:

CustomModel(
  (quant): Quantize(scale=tensor([0.0078]), zero_point=tensor([0]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), scale=0.004018464591354132, zero_point=0)
  (dequant): DeQuantize()
)

From this I can see that:

  • Input Scale = 0.0078
  • Input Zero Point = 0
  • Output Scale of Quantized Conv2d = 0.0040
  • Output Zero Point of Quantized Conv2d = 0
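
Note that the values printed above are rounded for display; if needed, the exact parameters can also be read programmatically from the converted modules. A minimal sketch using the standard attributes of the eager-mode quantized modules:

print(net.quant.scale.item(), net.quant.zero_point.item())  # exact input scale / zero point
print(net.conv1.scale, net.conv1.zero_point)                 # exact output scale / zero point of the quantized conv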

I create a sample input:

channel1 = torch.arange(0, 64).view(4, 16).to(torch.uint8).unsqueeze(0)
input_data = channel1.unsqueeze(0)
input_data = input_data / 255

I register a forward hook on every leaf module because I want to save the activations in the model:

activations = []
def custom_hook(module, input, output):
    info = {
        'module': module,
        'input': input,
        'output': output
    }
    activations.append(info)

for name, module in net.named_modules():
    if len(list(module.children())) == 0:
        module.register_forward_hook(custom_hook)

Feed the input data to the model:

_ = net(input_data)

Extract the input and output of the quantized conv2d:

activations[1]['input'][0].int_repr()
activations[1]['output'].int_repr()
# input
tensor([[[[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8],
          [ 8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16],
          [16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24],
          [24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31, 32]]]],
       dtype=torch.uint8)

# output
tensor([[[[ 8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 13, 13, 14, 14, 15],
          [15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23],
          [23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 28, 28, 29, 29, 30],
          [30, 31, 31, 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38]]]],
       dtype=torch.uint8)

I find that when the quantized input value is 24, the output is 30.

I check the weight and bias of the quantized conv2d and their integer representations:

net.conv1.weight()
net.conv1.weight().int_repr()
net.conv1.bias()
# weight
tensor([[[[0.4808]]]], size=(1, 1, 1, 1), dtype=torch.qint8,
       quantization_scheme=torch.per_channel_affine,
       scale=tensor([0.0038], dtype=torch.float64), zero_point=tensor([0]),
       axis=0)

# int repr
tensor([[[[127]]]], dtype=torch.int8)

# bias
tensor([0.0317], requires_grad=True)

From this I can see that:

  • Weight Scale = 0.0038
  • Weight Zero Point = 0
  • Weight = 127
  • Bias (before quantized) = 0.0317
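
Again, the printed values are rounded for display; the exact per-channel weight scale, zero point, and fp32 bias can be read directly, e.g.:

w_q = net.conv1.weight()
print(w_q.q_per_channel_scales()[0].item())       # exact weight scale for the single output channel
print(w_q.q_per_channel_zero_points()[0].item())  # weight zero point (0 here)
print(net.conv1.bias().item())                    # fp32 bias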

After collecting all this information, I try to reproduce the output (30) from the input (24), but it fails
(the calculation formula is from here, provided by @jerryzh168):

q_x = 24
z = q_x * 127  # z = q_x * w
bias_q = round(0.0317/(0.0038*0.0078))  # bias_q = round(bias/(weight scale * input scale))
z_int = z + bias_q
z_out = z_int * 0.0078 * 0.0038 / 0.0040 + 0  # z_out = z_int * input_scale * weight_scale / output_scale + output_zp
z_out  # 30.514379999999996 round to 31
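
(For comparison, a quick float-side check using the printed weight 0.4808 and bias 0.0317 — dequantize the input, apply the fp32 conv, requantize with the conv's output scale — does land on the observed 30:)

x_fp = 24 * 0.0078                      # dequantize the integer input value
y_fp = x_fp * 0.4808 + 0.0317           # a 1x1 conv with one weight is a single multiply-add
y_q = y_fp / 0.004018464591354132 + 0   # requantize with the conv's output scale and zero point
print(round(y_q))                       # ~30.29 -> 30, which matches the observed output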

I have no idea where the mistake is and need your help. Thanks for any advice!

Not sure (what I had here previously was wrong).

We have unit tests that check this calculation,

i.e. https://github.com/pytorch/pytorch/blob/54217e695d71430cff79f04933d4b1384ba99610/test/quantization/core/test_quantized_op.py#L4420

with the main part being https://github.com/pytorch/pytorch/blob/54217e695d71430cff79f04933d4b1384ba99610/test/quantization/core/test_quantized_op.py#L4319

i.e. for a quantized_conv op with a quantized weight W_q on a quantized input X_q, this is equivalent to an fp32 conv op with the weight W_q.dequantize() on X_q.dequantize(). If you take the output of that and quantize it using the scale and zero point from the quantized conv, you should get the same result (per this unit test).
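
A minimal sketch of that check, assuming the hooks and `activations` list from the first post are still in place:

import torch.nn.functional as F

x_q = activations[1]['input'][0]   # quantized input captured by the hook
w_q = net.conv1.weight()           # quantized weight

with torch.no_grad():              # the fp32 bias still has requires_grad=True
    y_fp32 = F.conv2d(x_q.dequantize(), w_q.dequantize(), net.conv1.bias(),
                      stride=net.conv1.stride, padding=net.conv1.padding)
y_req = torch.quantize_per_tensor(y_fp32, net.conv1.scale, net.conv1.zero_point, torch.quint8)

# Should print True, per the unit test referenced above
print(torch.equal(y_req.int_repr(), activations[1]['output'].int_repr()))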

@HDCharles, thanks for your help! It's really helpful. However, I have another question.

I want to deploy my quantized model on INT8 hardware, so I cannot use the dequantized input or the dequantized weight/bias.

Given the quantized input and the quantized conv weight/bias, I will do the conv operation manually on my hardware:

q_x = 24
z = q_x * 127  # z = q_x * w
bias_q = round(0.0317/(0.0038*0.0078))  # bias_q = round(bias/(weight scale * input scale))
z_int = z + bias_q
z_out = z_int * 0.0078 * 0.0038 / 0.0040 + 0  # z_out = z_int * input_scale * weight_scale / output_scale + output_zp

If I cannot get the correct result here, I am afraid my quantized model will produce bad results on the INT8 hardware.

Could you give me some advice on how to solve this problem? Thank you!

To put it bluntly, my team (the quantization team) doesn't do anything with the actual quantized kernels. That's the domain of the fbgemm or qnnpack team that actually developed those kernels. They are more complicated than they look, so I'm not going to be able to point to the specific part of your calculation that's different, as I only have a relatively high-level understanding of what's going on.

Having said that, I can provide you with some tools that may make your life easier. You are trying to compare a specific calculation to the quantized kernel, which hides a significant amount of complexity. If it were me, I'd compare instead to the reference implementation, and from there I'd be able to determine whether it's the quant, the conv, or the requant that is causing the divergence.

You could also look at the reference qlinear op, which uses only elementary operations (the reference qconv op, which calls conv, is not ideal for this).

Also, there's this in-depth document that goes into some of the specifics of quantized matmul and its justifications (though not ones specific to fbgemm): gemmlowp/quantization.md at master · google/gemmlowp · GitHub, which I think is largely what you are trying to replicate. I believe it mimics the qlinear reference implementation above.

You can also see the code for the actual quantize operation that you’re trying to duplicate here:

Note line 72, which I don't think is the issue, but is certainly not something you're capturing, i.e. round-to-even on ties.
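
For reference, round-to-even on ties looks like this (torch.round follows the same convention, whereas a naive round-half-up differs on exact .5 ties):

t = torch.tensor([0.5, 1.5, 2.5, 3.5])
print(torch.round(t))   # tensor([0., 2., 2., 4.]) -- ties go to the nearest even integer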

It's just because you use the numbers printed on screen as your scale parameters rather than the scale parameters themselves, which introduces numerical error relative to the true scale parameters; hence the difference between the original result and the result you calculate.
Please try the following code and check whether it corrects your result.

# Integer representations of the quantized input and weight
qx = activations[1]['input'][0].int_repr()
wx = net.conv1.weight().int_repr()

# Exact (full-precision) quantization parameters, read from the tensors/modules
sinput = activations[1]['input'][0].q_scale()
sweight = activations[1]['module'].weight().q_per_channel_scales()[0]
soutput = activations[1]['module'].scale
zinput = activations[1]['input'][0].q_zero_point()    # 0 here, so it drops out of the formula below
zweight = activations[1]['module'].weight().q_per_channel_zero_points()[0]  # also 0 here
zoutput = activations[1]['module'].zero_point

# Quantize the fp32 bias with scale = input_scale * weight_scale
bias = activations[1]['module'].bias()
qbias = torch.round(bias / (sinput * sweight))

# Integer conv (1x1, single channel) followed by requantization
qoutput = qx * wx + qbias
qoutput = torch.round(qoutput * sinput * sweight / soutput + zoutput)
qoutput = torch.clamp(qoutput, 0, 127)  # reduce_range=True limits quint8 activations to [0, 127]
print((activations[1]['output'].int_repr() == qoutput).sum())

And the result in my implementation is

>>> tensor(64)

which means all 64 values match the original results exactly.
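
Applying the same formula to just the value the original question asked about likewise reproduces it. The indices below pick the position in the 4x16 map where the integer input equals 24 (row 3, column 0 in the tensors printed earlier), where the hooked output is 30:

v = qx[0, 0, 3, 0].float() * wx[0, 0, 0, 0].float() + qbias[0]
v = torch.clamp(torch.round(v * sinput * sweight / soutput + zoutput), 0, 127)
print(int(v))   # 30, as observed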
