How PyTorch simulates bias during quantization-aware training

It seems that PyTorch QAT doesn't simulate bias quantization error during training. I found that qat.Conv2d only fake-quantizes the weight and activation. So PyTorch's quantization strategy does not quantize the bias, right?

Yes, we do not quantize the bias. There have been some internal discussions on this before. The problem with quantizing the bias is that it needs to be quantized with the quantization parameters of the input and the weight, but the input can come from dynamic paths, e.g.:

if x > 0:
    y = myConv1(x)
else:
    y = myConv2(x)

z = myConv3(y)

and we have no way of getting this information in eager mode. Currently we pass the bias in fp32, and it is quantized inside the quantized ops like quantized::conv2d with the quantization parameters of the input and weight: y = conv(x_q, w_q) + bias / (w_scale * x_scale).
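For concreteness, here is a small sketch of that on-the-fly bias quantization, using made-up scale values rather than anything from a real model:

import torch

# hypothetical scales, purely for illustration
x_scale, w_scale = 0.05, 0.002
bias_fp32 = torch.tensor([0.25, -0.10, 0.03])

# inside quantized::conv2d the fp32 bias is quantized on the fly with
# scale = x_scale * w_scale (zero_point = 0) and kept in int32
bias_q = torch.round(bias_fp32 / (x_scale * w_scale)).to(torch.int32)
print(bias_q)  # tensor([ 2500, -1000,   300], dtype=torch.int32)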

However, for QAT, I think we currently do not simulate this behavior. I'm not sure how much impact this has, though; we'll discuss it. Thanks for the question.


So, if I want to transfer the quantization-aware-trained network to my hardware, how exactly should I implement the bias part?
Should I use the above formula to quantize it?

Right now the quantization of the bias is not modeled in quantization-aware training, so there might be a small discrepancy between the QAT model and the model after convert, but I think it won't matter too much.

Thank you for the response, Jerry.
So, what should I do with the bias parameter of the batch-norm module when I want to implement my quantized model on hardware? The final converted (quantized) model still has this parameter (in FP) in the quantized version of ConvBnReLU2d.

  • Would the bias be ignored entirely when we call the quantized model on some input X (model.eval())?
  • Or are the intermediate feature values temporarily converted to FP so the bias can be applied, and then converted back to INT8/INT32?
  • Or is the bias also converted to INT8 with a simple choice of scale or zero point, without the influence of the QAT part?

The bias is an input to the quantized::conv2d op, and it is applied inside quantized::conv2d itself, with this formula:

y = conv(x_q, w_q) + round(bias / (w_scale * x_scale))

This is in int32; then we requantize y with output_scale and output_zero_point.
cc @dskhudia could you link the fbgemm implementation for conv?
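As a rough sketch of that last step (with assumed example scales and output qparams, not the actual fbgemm kernel code), the int32 result of conv(x_q, w_q) plus the int32 bias is requantized roughly like this:

import torch

# int32 accumulator: conv(x_q, w_q) + bias_q (example values)
acc_int32 = torch.tensor([12345, -678, 90], dtype=torch.int32)

x_scale, w_scale = 0.05, 0.002               # input and weight scales (assumed)
output_scale, output_zero_point = 0.1, 128   # output qparams (assumed)

# requantize: map the int32 accumulator into the uint8 output domain
y_q = torch.round(acc_int32 * (x_scale * w_scale) / output_scale) + output_zero_point
y_q = torch.clamp(y_q, 0, 255).to(torch.uint8)
print(y_q)  # tensor([140, 127, 128], dtype=torch.uint8)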


We find that modeling the bias in QAT is not very important, since it doesn't affect accuracy much. One workaround is to remove the bias from the Conv and add it explicitly outside of the conv, so that adding the bias can be modeled with an add op.


Thanks for your reply, modeling the bias with an add op sounds good!

Hello, can you explain in more detail how to remove the bias from the conv?

Just set the bias to None in the conv and add an explicit add after it, as in the sketch below.
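A minimal sketch of that workaround in eager mode, assuming a module name, a QuantStub for the bias, and nn.quantized.FloatFunctional for the add (none of which are prescribed in this thread), could look like this:

import torch
import torch.nn as nn
import torch.quantization as tq

class ConvWithExplicitBias(nn.Module):
    """Sketch: bias-free conv followed by an explicit, quantizable add."""
    def __init__(self, in_ch, out_ch, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, bias=False)
        # the bias removed from the conv, kept as a broadcastable parameter
        self.bias = nn.Parameter(torch.zeros(out_ch, 1, 1))
        # QuantStub lets an observer/fake-quant be attached to the bias itself
        self.bias_quant = tq.QuantStub()
        # FloatFunctional makes the add visible to eager-mode quantization
        self.add = nn.quantized.FloatFunctional()

    def forward(self, x):
        y = self.conv(x)
        b = self.bias_quant(self.bias)
        return self.add.add(y, b)

During QAT, prepare_qat would then attach fake-quant modules to the output of the add (and to the bias via the QuantStub), so the bias error is modeled. Whether the converted quantized add handles the broadcasted bias shape may depend on the PyTorch version, so it is worth verifying on a converted model.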