I ran the following code to quantize ResNet18.
The accuracy is Acc@1 83.444 / Acc@5 96.090 when the model is not quantized (i.e. float32).
The accuracy is Acc@1 82.606 / Acc@5 95.846 when it is quantized.
My questions are:
1. Where can I check the config for the quantized data type (fake int8, fake int16, or fake int4)? (See also the inspection sketch I pasted after the code below.)
2. Is the bias quantized or not, and where can I check that in the code? I noticed in some posts that the bias is not quantized, but I would like to confirm it. A follow-up question: if the bias is not quantized, how can I run inference on real int8/int16 hardware?
def main():
    args = prepare_params()
    model = eval('hubconf.{}(pretrained=True)'.format('resnet18'))
    model.eval()
    m = model
    data_path = r'D:\datasets\imagenet'
    train_loader, val_loader = load_data(path=data_path, batch_size=64)
    # cali_data = load_calibrate_data(train_loader, cali_batchsize=16)
    logging.info('---- Accuracy before PTQ ----')
    with torch.no_grad():
        evaluate(val_loader, m, device="cuda", print_freq=20)

    # Step 1. program capture
    # example_inputs = torch.randn((3, 224, 224))
    example_inputs = (torch.randn(1, 3, 224, 224),)
    # example_inputs = (example_inputs, )
    # NOTE: this API will be updated to the torch.export API in the future,
    # but the captured result should mostly stay the same
    m = capture_pre_autograd_graph(m, example_inputs)
    # we now have a model with aten ops

    # Step 2. quantization
    # backend developers write their own Quantizer and expose methods that let
    # users express how they want the model to be quantized
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    # or prepare_qat_pt2e for Quantization Aware Training
    m = prepare_pt2e(m, quantizer)
    # run calibration
    calibrate(m, train_loader)
    m = convert_pt2e(m)
    with torch.no_grad():
        evaluate(val_loader, m, device="cuda", print_freq=20)

    # Step 3. lowering
    # lower to target backend
    return
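Regarding question 1, here is how I tried to inspect what the quantizer config actually requests. The field names assume the QuantizationConfig dataclass used by the XNNPACK quantizer in recent PyTorch releases, so treat this as a sketch rather than a guaranteed API:

    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
    )

    qcfg = get_symmetric_quantization_config()
    # each field is a QuantizationSpec (or None); its dtype tells you which
    # fake-quant data type is requested for that tensor
    print(qcfg.input_activation)   # expecting something like dtype=torch.int8
    print(qcfg.weight)             # expecting something like dtype=torch.int8
    print(qcfg.bias)               # expecting a float spec / None, i.e. bias not fake-quantized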
@Gurkirt,
Thank you for your reply!
In my code in this post, if you check the function get_symmetric_quantization_config in the line quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()), you will find that both the weights and the activations are quantized to int8. However, I did not find any
quantization of the bias in the code.
There is one more problem about bias quantization:
If we do not quantize the bias, we cannot run int8 or int16 inference on hardware.
If we do quantize the bias, the output of conv2d can no longer be dequantized. This is a serious problem. Let me show a piece of code to explain the reason:
# code 1: fake quantization
x = self.conv_module(x)  # fake int8 or fake int16
# because the bias is added to the multiply-accumulate result, the output x is float
x = F.relu(x)

# code 2: int8 or int16 hardware inference
x = self.conv_module(x)  # int8 or int16
# because the bias is added to the multiply-accumulate result, there is no single
# scale we can use to dequantize the output of conv_module
x = F.relu(x)
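To make the problem concrete, here is a small NumPy sketch (made-up tensors and scales, a 1-D dot product standing in for conv2d) showing that the fake-quant path adds the float bias after dequantizing, while an integer-only path has to quantize the bias at scale_x * scale_w so that the whole accumulator stays on one scale:

    import numpy as np

    np.random.seed(0)
    x = np.random.randn(16).astype(np.float32)
    w = np.random.randn(16).astype(np.float32)
    bias = np.float32(0.37)

    # symmetric int8 scales (plain abs-max calibration for this example)
    scale_x = np.abs(x).max() / 127.0
    scale_w = np.abs(w).max() / 127.0
    qx = np.clip(np.round(x / scale_x), -128, 127).astype(np.int32)
    qw = np.clip(np.round(w / scale_w), -128, 127).astype(np.int32)

    acc = int(qx @ qw)                  # int32 accumulator, implicit scale = scale_x * scale_w

    # fake-quant path: dequantize the accumulator first, then add the float bias
    fake = acc * scale_x * scale_w + bias

    # integer-only path: quantize the bias at the accumulator's scale, add in int32,
    # dequantize once at the end with that same single scale
    qb = int(round(bias / (scale_x * scale_w)))
    int_only = (acc + qb) * (scale_x * scale_w)

    print(x @ w + bias, fake, int_only)  # all three values should be close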
@Ardeal Yes, you are right that the bias is not explicitly quantized, but the output is quantized after the bias is added.
Taking the snippet from @jerryzh168 in the thread above:
z = qconv(wq, xq)
# z is at scale (weight_scale * input_scale) and stored in int32:
# int8 * int8 gives int16, and accumulating int16 over
# kernel_size * kernel_size * input_channels terms fits in int32
bias_q = round(bias / (input_scale * weight_scale))
# convert the float32 bias to int32 at the accumulator's scale;
# this can be done ahead of time: even though the bias is kept float during
# quantization, you can store it in an int32 vector
# in practice, your compute kernel can keep everything up to this point in int32
z_int = z + bias_q
# perform the 32-bit add
z_out = round(z_int * (input_scale * weight_scale) / output_scale) - z_zero_point
# requantize to 8 bits; this can also be done with bit-shift operations, if you are
# careful, so that you don't have to store input_scale * weight_scale / output_scale as a float
z_out[z_out < 0] = 0  # ReLU
z_out = saturate(z_out)
I think this is more or less right, but I could be wrong.
As you can see, I am also struggling to get this completely right.
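Here is the same pipeline as a runnable NumPy sketch (shapes, scales, and zero point are made up; a flattened patch/filter dot product stands in for qconv), so the arithmetic can be checked end to end:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal(32).astype(np.float32)   # one conv patch, flattened
    w = rng.standard_normal(32).astype(np.float32)   # one filter, flattened
    bias = np.float32(0.25)

    input_scale = np.abs(x).max() / 127.0
    weight_scale = np.abs(w).max() / 127.0
    # output scale would come from calibration; picked here so the result fits in int8
    output_scale = np.float32(abs(float(x @ w + bias)) / 100.0)
    z_zero_point = 0                                 # symmetric output assumed

    xq = np.clip(np.round(x / input_scale), -128, 127).astype(np.int32)
    wq = np.clip(np.round(w / weight_scale), -128, 127).astype(np.int32)

    z = int(xq @ wq)                                 # "qconv": int32 accumulator
    bias_q = int(round(bias / (input_scale * weight_scale)))
    z_int = z + bias_q                               # 32-bit add

    z_out = round(z_int * (input_scale * weight_scale) / output_scale) - z_zero_point
    z_out = max(z_out, 0)                            # ReLU
    z_out = int(np.clip(z_out, -128, 127))           # saturate

    print(max(float(x @ w + bias) / output_scale, 0.0), z_out)  # float reference vs integer path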
The snippet in your post above is very helpful. z = qconv(wq, xq) performs the quantized conv2d, but z_int = z + bias_q adds the bias after the conv2d.
That means the qconv function does not add the bias. This is the only solution I had previously thought of.
For PTQ, the Python code should collect the scale, zero_point, max, and min statistics from calibration data. However, I did not see any Python code that quantizes the bias. Did you see such code?
One more question: if we quantize the bias by adding it after conv2d, how does that affect the accuracy of the neural network?
It should be exactly the same as what you get from PyTorch, since current PyTorch quantization is just a wrapper around backend kernels (x86, xnn, onednn, cudnn): at runtime (I assume) the bias is quantized by the respective backend kernel. Weight quantization does not require calibration, and the same applies to the bias; calibration is required to calculate the input and output scales and zero points.
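To illustrate that last point, a toy sketch (the observer logic is simplified and the calibration data is a stand-in): the weight scale falls straight out of the stored tensor, while the activation scale needs data to be observed:

    import torch

    # weight scale: no calibration needed, it comes straight from the stored tensor
    w = torch.randn(64, 3, 3, 3)
    weight_scale = w.abs().max() / 127.0          # symmetric int8

    # activation scale: needs calibration batches, because the range depends on real inputs
    calibration_loader = [(torch.randn(8, 3, 224, 224), None) for _ in range(4)]  # stand-in for real data
    running_min, running_max = float("inf"), float("-inf")
    for images, _ in calibration_loader:
        running_min = min(running_min, images.min().item())
        running_max = max(running_max, images.max().item())
    activation_scale = (running_max - running_min) / 255.0  # affine uint8-style range

    print(weight_scale.item(), activation_scale)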
Correct. I have been searching for many days, and I have not seen any open-source code (including PyTorch, or other open-source projects related to PyTorch) that quantizes the bias.
That is your code. I checked it, but it does not seem to be the answer I am looking for. Let me re-explain my question with the following pseudo-code.
There are 2 questions about my pseudo-code:
1. I am not sure whether solution 2 is the best solution, or whether there is another solution that performs better.
2. Where can I find code that adds this functionality to PyTorch PTQ and can be used directly?
# each scale is computed from the corresponding tensor's own vmin/vmax
scale_weight = (vmax - vmin) / (qmax - qmin)
scale_x = (vmax - vmin) / (qmax - qmin)
scale_bias = (vmax - vmin) / (qmax - qmin)
scale_qo = (vmax - vmin) / (qmax - qmin)

qw = weight / scale_weight  # round and clamp to int8
qb = bias / scale_bias      # round and clamp to int8
qx = x / scale_x            # round and clamp to int8

# solution 1: bad/wrong solution
qo = conv2d(qx)  # qo = qw*qx + qb, fake int8 or int8 calculation
# qo may exceed the int8 range and needs to be dequantized and requantized
de_qo = qo * what_scale?  # because the bias was added to the product, there is no single scale to dequantize qo

# solution 2: good solution
qo = conv2d(qx, bias=False)          # qo = qw*qx, don't add the bias inside conv2d
de_qo = qo * scale_weight * scale_x  # dequantize qo
de_bias = qb * scale_bias            # dequantize the bias
de_qoo = de_qo + de_bias             # add them in (fake) float
re_de_qo = de_qoo / scale_qo         # requantize to fake int8 or int8
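For my own sanity, here is a quick numerical check (made-up tensors and scales, a 1-D dot product standing in for conv2d) that solution 2 matches, up to rounding, the approach in your earlier snippet where the bias is quantized at scale_weight * scale_x and added to the int32 accumulator:

    import numpy as np

    rng = np.random.default_rng(1)
    x = rng.standard_normal(64).astype(np.float32)
    weight = rng.standard_normal(64).astype(np.float32)
    bias = np.float32(-0.8)

    # made-up per-tensor scales (abs-max style); scale_qo would come from calibration
    scale_x = np.abs(x).max() / 127.0
    scale_weight = np.abs(weight).max() / 127.0
    scale_bias = abs(float(bias)) / 127.0
    scale_qo = abs(float(x @ weight + bias)) / 100.0

    qx = np.clip(np.round(x / scale_x), -128, 127).astype(np.int32)
    qw = np.clip(np.round(weight / scale_weight), -128, 127).astype(np.int32)
    qb = round(float(bias) / scale_bias)

    # solution 2: conv without bias, dequantize, add the dequantized bias, requantize
    de_qo = (qx @ qw) * scale_weight * scale_x
    de_bias = qb * scale_bias
    re_de_qo = round((de_qo + de_bias) / scale_qo)

    # integer-only variant from the earlier snippet: bias quantized at scale_weight * scale_x
    qb32 = round(float(bias) / (scale_weight * scale_x))
    qo_int = round((int(qx @ qw) + qb32) * scale_weight * scale_x / scale_qo)

    print(re_de_qo, qo_int)  # should differ by at most a quantization step or two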
PyTorch does not quantize the bias, so if you are looking for a PyTorch solution, you are out of luck.
What you can do is use input_scale and the kernel (weight) scale to quantize the bias, and then use it in your own kernel, which you can write for your device, like I have done in the code above.
I am going to take an open-source project that quantizes the weights and activations but not the bias, and insert my pseudo-code above to add bias quantization.