Manual quantization result comparison reveals mismatch

I created a simple model with a quantization and dequantization block:

import torch
import torch.nn as nn

class M_only_quant_dequant(nn.Module):

    def __init__(self):
        super(M_only_quant_dequant, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        # DeQuantStub converts tensors from quantized back to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.dequant(x)
        return x

I passed a floating-point input through the prepared model to calibrate the scale and zero point:

input_fp32 = torch.tensor([[[[ 0.95466703176498413086, -0.136212718486785888672,
            0.75253891944885253906,  1.57104063034057617188],
          [ 0.97250884771347045898, -0.67004448175430297852,
           -0.58047348260879516602,  1.30683445930480957031],
          [-0.13423979282379150391,  0.16391958296298980713,
           -0.71688455343246459961,  0.05846109613776206970],
          [ 1.07569837570190429688, -0.06351475417613983154,
           -0.19469638168811798096, -0.09430617839097976685]]]], requires_grad=False)
model_quant_only = M_only_quant_dequant()
model_quant_only.eval()
model_quant_only.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_quant_only_1 = torch.quantization.prepare(model_quant_only)
model_quant_only_1(input_fp32)   # pass the input through the model before conversion to calibrate
model_quant_only_converted_1 = torch.quantization.convert(model_quant_only_1, inplace=True)
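
After conversion, the quant submodule stores the calibrated parameters; these are the values I refer to in the manual computation below:

# the converted quant submodule holds the calibrated scale and zero point
print(model_quant_only_converted_1.quant.scale)
print(model_quant_only_converted_1.quant.zero_point)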

The output of the model came out as follows. There is some difference between the input and the output because of quantization:

tensor([[[[ 0.950229287147521972656250000000, -0.135747045278549194335937500000,
            0.746608734130859375000000000000,  1.578059434890747070312500000000],
          [ 0.967197716236114501953125000000, -0.576924920082092285156250000000,
           -0.576924920082092285156250000000,  1.306565284729003906250000000000],
          [-0.135747045278549194335937500000,  0.169683814048767089843750000000,
           -0.576924920082092285156250000000,  0.050905141979455947875976562500],
          [ 1.069007992744445800781250000000, -0.067873522639274597167968750000,
           -0.186652183532714843750000000000, -0.101810283958911895751953125000]]]])

I manually computed the quantized tensor and performed the dequantization using the following equations:

input_quant_manual = torch.round(input_fp32.detach() / model_quant_only_converted_1.quant.scale) + model_quant_only_converted_1.quant.zero_point
input_dequant_manual = (input_quant_manual - model_quant_only_converted_1.quant.zero_point) * model_quant_only_converted_1.quant.scale
input_dequant_manual =
tensor([[[[ 0.950229287147521972656250000000, -0.135747045278549194335937500000,
            0.746608734130859375000000000000,  1.578059434890747070312500000000],
          [ 0.967197716236114501953125000000, -0.661766827106475830078125000000,
           -0.576924920082092285156250000000,  1.306565284729003906250000000000],
          [-0.135747045278549194335937500000,  0.169683814048767089843750000000,
           -0.712671995162963867187500000000,  0.050905141979455947875976562500],
          [ 1.069007992744445800781250000000, -0.067873522639274597167968750000,
           -0.186652183532714843750000000000, -0.101810283958911895751953125000]]]])

However, there is a difference in only two of the values, while the rest match:

input_dequant_manual - model_quant_only_converted_1(input_fp32) = 

tensor([[[[ 0.000000000000000000000000000000,  0.000000000000000000000000000000,
            0.000000000000000000000000000000,  0.000000000000000000000000000000],
          [ 0.000000000000000000000000000000, -0.084841907024383544921875000000,
            0.000000000000000000000000000000,  0.000000000000000000000000000000],
          [ 0.000000000000000000000000000000,  0.000000000000000000000000000000,
           -0.135747075080871582031250000000,  0.000000000000000000000000000000],
          [ 0.000000000000000000000000000000,  0.000000000000000000000000000000,
            0.000000000000000000000000000000,  0.000000000000000000000000000000]]]])

When I compared the integer representation produced by the PyTorch QuantStub with my manual one, I saw that the negative values are set to 0 by QuantStub():

input_quant_manual
tensor([[[[ 90.,  26.,  78., 127.],
          [ 91.,  -5.,   0., 111.],
          [ 26.,  44.,  -8.,  37.],
          [ 97.,  30.,  23.,  28.]]]])
model_quant_only_converted_1.quant(input_fp32).int_repr()
tensor([[[[ 90,  26,  78, 127],
          [ 91,   0,   0, 111],
          [ 26,  44,   0,  37],
          [ 97,  30,  23,  28]]]], dtype=torch.uint8)

Because of this, there is some difference when I dequantize manually. But why is QuantStub setting the negative numbers in the integer representation to zero?


Hi Avishek,

I believe this is because QuantStub is quantizing to uint8, which is an unsigned 8-bit integer.
Your manual quantize step is missing a clamp from qmin to qmax (0 and 127 respectively), which I think is the cause of the discrepancy.

manual_quant = torch.clamp(torch.round(input_fp32 / quant.scale) + quant.zero_point, min=0, max=127)
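
As a quick sanity check, reusing the tensor and model names from your post (just a sketch, not tested):

scale = model_quant_only_converted_1.quant.scale
zero_point = model_quant_only_converted_1.quant.zero_point

# redo the manual quantize/dequantize, this time clamping to [qmin, qmax]
input_quant_clamped = torch.clamp(
    torch.round(input_fp32.detach() / scale) + zero_point, min=0, max=127)
input_dequant_clamped = (input_quant_clamped - zero_point) * scale

# this should now agree with the converted model's output
print(torch.allclose(input_dequant_clamped, model_quant_only_converted_1(input_fp32)))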

Hi Jesse,

Thanks for your response. Yes, I figured that the results would match if I set the negative values to 0. But I would like to know why uint8 is being used and not int8, so that the whole range from -128 to +127 could be covered.

Also, do you know where the scale and zero_point are applied in the PyTorch code? I looked at pytorch/torch/ao/quantization at master · pytorch/pytorch · GitHub but didn't manage to locate it.

You can see (and change) the quantization target dtype in the qconfig.

print(torch.quantization.get_default_qconfig('fbgemm'))

On my machine this returns:

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){},
        weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})

This will cause the quantization step to use the default settings of HistogramObserver, whose default dtype is quint8.
You can change the qconfig to use qint8 like this:

from functools import partial

custom_qconfig = torch.quantization.get_default_qconfig('qnnpack')

# build a QConfig whose activation observer quantizes to qint8 instead of quint8
new_qconfig = torch.quantization.QConfig(
    activation=partial(custom_qconfig.activation, dtype=torch.qint8),
    weight=custom_qconfig.weight,
)
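
A minimal sketch of applying it, reusing M_only_quant_dequant and input_fp32 from your post (untested, and whether downstream quantized ops accept qint8 activations depends on the backend):

model_int8 = M_only_quant_dequant()
model_int8.eval()
model_int8.qconfig = new_qconfig
model_int8_prepared = torch.quantization.prepare(model_int8)
model_int8_prepared(input_fp32)  # calibration pass
model_int8_converted = torch.quantization.convert(model_int8_prepared)
# with qint8 the integer representation can hold negative values
print(model_int8_converted.quant(input_fp32).int_repr())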

The quantization step happens in the C++ codebase here.
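
For reference, a rough Python equivalent of the affine mapping those kernels implement (quantize_affine and dequantize_affine are just illustrative names, not PyTorch functions):

def quantize_affine(x, scale, zero_point, qmin=0, qmax=255):
    # q = clamp(round(x / scale) + zero_point, qmin, qmax)
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.uint8)

def dequantize_affine(q, scale, zero_point):
    # x_hat = (q - zero_point) * scale
    return (q.to(torch.float32) - zero_point) * scale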
