How is quantization of activations handled in PyTorch after QAT?

I am compiling a quantized PyTorch model with TVM and using ReLU6 as the activation for the conv layers, but the output of the model changes dramatically.

TVM quantizes the value 6 using the input scale and input zero-point that come with the PyTorch model. In one case, the input scale is 0.019743409007787704 and the input zero-point is 0.

So the quantized value of 6 is computed as:

6 / 0.019743409007787704 + 0

which is approximately 303.9.

This value exceeds the max value that can be represented by a uint8, which is 255.
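A quick sanity check of that arithmetic in plain Python, using the numbers above:

scale, zero_point = 0.019743409007787704, 0

q = 6 / scale + zero_point
print(q)  # ~303.9, above the 255 ceiling of an 8-bit unsigned integer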

How does PyTorch handle quantization for activations such as hardtanh if there is no fusion? Does it quantize the min and max values based on the input's scale and zero point?

Note that I do not want to perform fusion as TVM does something similar when it optimizes the model.

It is not clear to me how the content leading up to the questions in the OP relates to the questions that were posed. Can you clarify?

How does PyTorch handle quantization for activations such as hardtanh if there is no fusion? Does it quantize the min and max values based on the input's scale and zero point?

I believe the activation has its own quantization parameters, and the quantization should be with respect to these parameters. Between the Conv and the activation layer, there will be a quant and a dequant operation.
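If the activation really did run in fp32, one way that pattern can look in eager mode (a minimal sketch, not taken from this thread; the module layout is made up for illustration) is to bracket the activation with DeQuantStub/QuantStub so it gets its own observer and therefore its own scale/zero-point:

import torch
import torch.nn as nn

class ConvThenHardtanh(nn.Module):
    # Hypothetical layout: the Conv stays quantized, the unfused activation runs in fp32.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.act = nn.Hardtanh(0.0, 6.0)
        self.dequant = torch.quantization.DeQuantStub()
        self.quant = torch.quantization.QuantStub()  # gets its own observer -> own scale/zero-point

    def forward(self, x):
        x = self.conv(x)
        x = self.dequant(x)   # back to fp32 before the unfused activation
        x = self.act(x)
        return self.quant(x)  # re-quantized with the activation's own qparams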

Btw, currently we only support fusion for relu; hardtanh is not supported with fusion (see the bottom of the Fuse Modules Recipe — PyTorch Tutorials 1.10.1+cu102 documentation).
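For reference, a minimal sketch of the fusion that is supported (Conv + ReLU; the toy module here is made up for illustration):

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

m = M().eval()
fused = torch.quantization.fuse_modules(m, [["conv", "relu"]])
print(fused.conv)  # ConvReLU2d after fusion; fused.relu is replaced by an Identity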


In PyTorch quantization you don't quantize to uint8, you quantize to quint8. Although it's stored (in part) like a uint8, that's not the value it represents. As an analogy, consider how everything in a PC is just 1s and 0s, but for fp32 data, the value those 1s and 0s represent is a decimal number.

Similarly, the 6 is the int value (6 < 255, so it's within the range of the internal representation), but it represents another value based on the scale and zero point.
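A small illustration of that duality: the same float values map to different stored integers depending on the qparams, and dequantize() recovers the values they represent:

import torch

x = torch.tensor([1.0, 2.0, 3.0])
a = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
b = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.quint8)

print(a.int_repr())    # tensor([1, 2, 3], dtype=torch.uint8)
print(b.int_repr())    # tensor([20, 40, 60], dtype=torch.uint8)
print(b.dequantize())  # approximately [1.0, 2.0, 3.0] again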

Here is a good explanation of the math behind quantization.


Thank you for your explanation, David. So the scale and zero-point of the activation are completely independent of the input it is given when it is called, right? I assume these values are calculated during QAT and stored somewhere for inference.

I save my model after QAT like below:

quantized_gen = torch.quantization.convert(gen, inplace=False)
torch.save(quantized_gen.state_dict(), gen_checkpoint)

Later I load the saved state_dict and I don't see any values corresponding to scale and zero point for activations.
Here is what state_dict shows:

'encoder.features.0.0.weight',
'encoder.features.0.0.bias',
'encoder.features.0.0.scale', tensor(0.2315)
'encoder.features.0.0.zero_point', tensor(0)
'encoder.features.1.conv.0.0.weight',
'encoder.features.1.conv.0.0.bias',
'encoder.features.1.conv.0.0.scale', tensor(0.3612)
'encoder.features.1.conv.0.0.zero_point', tensor(0)
'encoder.features.1.conv.1.weight',
'encoder.features.1.conv.1.bias',
'encoder.features.1.conv.1.scale', tensor(0.3612)
'encoder.features.1.conv.1.zero_point', tensor(65)
'encoder.features.1.skip_add.scale', tensor(1.)
'encoder.features.1.skip_add.zero_point', tensor(0)

How do I find the scale and zero point for activation?

Those are the scale and zero point for activations; the weight is already a quantized tensor (feel free to print it), and the bias is unquantized in our kernels.

The model is generally quantized with a QuantStub at the start that becomes a quantize_per_tensor/channel (which has a scale/zero_point that should show up in the state dict). From there the activations flow into a series of quantized ops.

The quantized op kernel takes (1) the qtensor of the activation (the qparams for this are intrinsic to the qtensor, so there is no need to pass them in separately), (2) the packed params (the quantized weight and the unquantized bias), and (3) the scale and zero point for the output (which is what you're seeing there). So each quantized op is fully defined by the kernel along with (2) and (3) in the state dict.
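To make that concrete, here is a hedged sketch (the toy model below is made up for illustration) of converting a small model and reading those pieces back: the weight returned by weight() is already a quantized tensor with its own qparams, and the module's scale/zero_point are the output qparams that appear in the state dict:

import torch
import torch.nn as nn

# Toy model: QuantStub -> Conv2d -> DeQuantStub
m = nn.Sequential(torch.quantization.QuantStub(),
                  nn.Conv2d(1, 1, 1),
                  torch.quantization.DeQuantStub())
m.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(m, inplace=True)
m(torch.randn(1, 1, 4, 4))             # calibrate so the observers record ranges
torch.quantization.convert(m, inplace=True)

qconv = m[1]                           # now a quantized Conv2d
print(qconv.weight())                  # quantized weight tensor (qparams intrinsic to it)
print(qconv.scale, qconv.zero_point)   # output scale / zero_point, stored in the state dict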


I think drawing it out would make my question more clear. Is the following diagram correct?

My question is what happens to the 0 and 6 in Hardtanh(0, 6). What scale and zero-point values are used to quantize those?

Nothing. There's no quantized hardtanh op, because it's an elementwise operation whose range is determined by the input, and it works equally well on quantized and non-quantized tensors.

>>> ht = torch.nn.Hardtanh()
>>> x = torch.randn(3, 3)
>>> xq = torch.quantize_per_tensor(x, 1.0, 0, torch.quint8)
>>> ht(xq)
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 1.]], size=(3, 3), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)

It's essentially the same as torch.clamp(x, -1, 1);
you don't need output qparams for that.

I assume you are using FX quantization, since hardtanh doesn't even show up in the eager-mode quantization mappings, so it wouldn't receive an observer there.

Edit: to answer your specific question about whether the diagram is correct, are you asking about how it works during QAT or once converted? For the converted model, the output is never fp32; the scale/zp are inputs to the quantized kernel and the output is already quantized. Also, nothing is ever int8; it's qint8 or quint8. In QAT nothing is in int8/qint8/quint8; everything happens in fp32 but goes through fake-quant ops that simulate the conversion while leaving the value in fp32. In that case the output does come out as fp32 and then goes into another fake-quant with a scale and zero point specified.
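A small sketch of the difference (the scale/zero-point values below are picked arbitrarily): during QAT the tensor stays fp32 and only the rounding/clamping is simulated, while after convert the tensor is actually stored as quint8:

import torch

x = torch.randn(4)
scale, zero_point = 0.02, 0

# QAT-style fake quantization: rounds/clamps as if quantized, but stays fp32.
x_fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 255)
print(x_fq.dtype)  # torch.float32

# After convert, the same step is a real quantize with quint8 storage.
x_q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
print(x_q.dtype)   # torch.quint8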


Makes sense. Thank you.

We quantize 0 and 6 with the same quantization parameters as the input; here is the implementation: pytorch/QuantizedOpKernels.cpp at master · pytorch/pytorch · GitHub
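Based on that, a rough sketch of the effect (assuming, per the linked kernel, that the clamp bounds 0 and 6 are quantized with the input's scale/zero-point and saturated to the quint8 range; the scale/zero-point below are arbitrary):

import torch
import torch.nn.functional as F

scale, zero_point = 0.1, 128
x = torch.randn(5) * 3
xq = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

# Quantize the bounds with the input's qparams, saturating to [0, 255].
q_min = min(max(round(0 / scale) + zero_point, 0), 255)  # 128
q_max = min(max(round(6 / scale) + zero_point, 0), 255)  # 188

print(F.hardtanh(xq, 0.0, 6.0).int_repr())
print(xq.int_repr().clamp(q_min, q_max))  # should match the line above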