Yes, you are right. I changed the function convert_to_reference_fx to _convert_to_reference_decomposed_fx, and the result is as expected (the reduce_range flag works).
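As a sanity check on where reduce_range comes from: the default fbgemm QAT qconfig builds its activation observer with reduce_range=True. A quick way to confirm this (assuming the fake-quant module exposes its observer as activation_post_process, as it does in torch 1.13):

from torch.ao.quantization.qconfig import get_default_qat_qconfig

qconfig = get_default_qat_qconfig('fbgemm')
act_fake_quant = qconfig.activation()  # instantiate the activation fake-quant module
# for fbgemm the observer is created with reduce_range=True, so activations are
# effectively quantized to [0, 127] instead of [0, 255]
print(act_fake_quant.activation_post_process.reduce_range)  # expected: True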
code:
import torch
import torch.nn.functional as F
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.backend_config.fbgemm import get_fbgemm_backend_config
from torch.ao.quantization.qconfig import get_default_qat_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx
from torch.ao.quantization.quantize_fx import convert_fx,convert_to_reference_fx,_convert_to_reference_decomposed_fx
import os
import copy
def quantize_per_tensor_uint8(x_fp32, scale, zero_point, quant_min, quant_max):
    x = x_fp32 / scale                        # fp32
    x = torch.round(x)                        # fp32
    x = x.to(dtype=torch.int32)               # int32
    x = x + zero_point                        # int32
    x = torch.clamp(x, quant_min, quant_max)  # int32
    x = x.to(dtype=torch.uint8)
    return x

def dequantize_per_tensor_uint8(x_i8, scale, zero_point):
    return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)

class Debug(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_debug = torch.nn.Conv2d(in_channels=3, out_channels=5, kernel_size=(3, 3),
                                          padding=0, stride=1, groups=1, bias=True)

    def forward(self, x):
        x = self.conv_debug(x)
        return x
if __name__ == '__main__':
    # torch.manual_seed(4)
    print('The default quantized engine is {}'.format(torch.backends.quantized.engine))
    Q_MAX_LIST = [127, 255]
    for Q_MAX in Q_MAX_LIST:
        for i in range(10):
            print('.....................When Q_MAX is {}, the result(loop {}):.................................'.format(Q_MAX, i))
            net = Debug()
            backend_config = get_fbgemm_backend_config()
            qconfig = get_default_qat_qconfig('fbgemm')
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            net.train()
            net_prepare = prepare_qat_fx(net, qconfig_mapping, torch.randn(1, 3, 10, 10), backend_config=backend_config)
            # run a couple of forward passes so the observers collect statistics
            net_prepare(torch.randn(1, 3, 10, 10))
            net_prepare(torch.randn(1, 3, 10, 10))
            net_converted = convert_fx(copy.deepcopy(net_prepare), qconfig_mapping=qconfig_mapping, backend_config=backend_config)
            net_converted_ref = _convert_to_reference_decomposed_fx(copy.deepcopy(net_prepare), qconfig_mapping=qconfig_mapping, backend_config=backend_config)
            input = torch.randn(1, 3, 10, 10)
            result_quant_ref = net_converted_ref(input)
            net_converted_state_dict = net_converted.state_dict()
            ############## input scale and zero_point ####################################
            scale_quant_input = net_converted_state_dict['conv_debug_input_scale_0']
            zero_point_quant_input = net_converted_state_dict['conv_debug_input_zero_point_0']
            ############## Conv2d scale and zero_point, int weight and float bias ####################################
            weight_conv_debug_float = net_converted_state_dict['conv_debug.weight']  # per-channel quantized weight
            weight_conv_debug_float_dequantize = weight_conv_debug_float.dequantize()
            scale_conv_debug = weight_conv_debug_float.q_per_channel_scales()
            zero_point_debug_int = weight_conv_debug_float.q_per_channel_zero_points()
            bias_conv_debug_float = net_converted_state_dict['conv_debug.bias']
            ############## output scale and zero_point ####################################
            scale_quant_output = net_converted_state_dict['conv_debug.scale']
            zero_point_quant_output = net_converted_state_dict['conv_debug.zero_point']
            #################### simulate the quantization process ################################
            # z_scale*(z_quant - z_zero_point) = x_scale*(x_quant - x_zero_point) * y_scale*(y_quant - y_zero_point), with y_zero_point = 0
            # get x_quant
            input_quant = quantize_per_tensor_uint8(
                x_fp32=input,
                scale=scale_quant_input,
                zero_point=zero_point_quant_input,
                quant_min=0,
                quant_max=Q_MAX)
            input_quant_dequant = dequantize_per_tensor_uint8(
                x_i8=input_quant,
                scale=scale_quant_input,
                zero_point=zero_point_quant_input)
            quantize_conv2d_reference = F.conv2d(
                input_quant_dequant,
                weight_conv_debug_float_dequantize,
                bias_conv_debug_float,
                1, 0, 1, 1)  # stride, padding, dilation, groups
            output_quant = quantize_per_tensor_uint8(
                x_fp32=quantize_conv2d_reference,
                scale=scale_quant_output,
                zero_point=zero_point_quant_output,
                quant_min=0,
                quant_max=Q_MAX)
            output_dequant = dequantize_per_tensor_uint8(
                x_i8=output_quant,
                scale=scale_quant_output,
                zero_point=zero_point_quant_output)
            # check the result against the reference quantized model
            close_flag = torch.allclose(result_quant_ref, output_dequant.to(torch.float32))
            print('close_flag={}'.format(close_flag))
            diff = result_quant_ref - output_dequant
            if not close_flag:
                count_big_diff = (torch.abs(diff) > 0.0000001).sum()
                diff_shape = diff.shape
                count_total = diff_shape[0] * diff_shape[1] * diff_shape[2] * diff_shape[3]
                print('non-aligned ratio:{:.2%}'.format(count_big_diff / count_total))
result:
The default quantized engine is x86
.....................When Q_MAX is 127, the result(loop 0):.................................
close_flag=True
.....................When Q_MAX is 127, the result(loop 1):.................................
close_flag=True
.....................When Q_MAX is 127, the result(loop 2):.................................
close_flag=True
.....................When Q_MAX is 127, the result(loop 3):.................................
close_flag=True
.....................When Q_MAX is 127, the result(loop 4):.................................
close_flag=True
.....................When Q_MAX is 127, the result(loop 5):.................................
close_flag=True
.....................When Q_MAX is 127, the result(loop 6):.................................
close_flag=True
.....................When Q_MAX is 127, the result(loop 7):.................................
close_flag=True
.....................When Q_MAX is 127, the result(loop 8):.................................
close_flag=True
.....................When Q_MAX is 127, the result(loop 9):.................................
close_flag=True
.....................When Q_MAX is 255, the result(loop 0):.................................
close_flag=False
non-aligned ratio:0.63%
.....................When Q_MAX is 255, the result(loop 1):.................................
close_flag=False
non-aligned ratio:0.31%
.....................When Q_MAX is 255, the result(loop 2):.................................
close_flag=True
.....................When Q_MAX is 255, the result(loop 3):.................................
close_flag=False
non-aligned ratio:27.19%
.....................When Q_MAX is 255, the result(loop 4):.................................
close_flag=True
.....................When Q_MAX is 255, the result(loop 5):.................................
close_flag=True
.....................When Q_MAX is 255, the result(loop 6):.................................
close_flag=False
non-aligned ratio:1.56%
.....................When Q_MAX is 255, the result(loop 7):.................................
close_flag=False
non-aligned ratio:2.50%
.....................When Q_MAX is 255, the result(loop 8):.................................
close_flag=True
.....................When Q_MAX is 255, the result(loop 9):.................................
close_flag=False
non-aligned ratio:0.31%
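This pattern is consistent with reduce_range being in effect: the scale and zero_point stored in the converted model were computed for a [0, 127] range, so when the simulation clamps to 255 instead of 127, only the elements that the reference graph saturates at 127 can differ, which matches the small mismatch ratios above. A minimal standalone sketch of the effect (scale, zero_point, and input values are made up for illustration):

import torch

# hypothetical qparams, as if an observer with reduce_range saw values in roughly [0, 12.7]
scale, zero_point = 0.1, 0
x = torch.tensor([5.0, 12.0, 15.0])  # 15.0 lies above the representable range

def fake_quant(x, scale, zp, qmin, qmax):
    q = torch.clamp(torch.round(x / scale) + zp, qmin, qmax)
    return (q - zp) * scale

print(fake_quant(x, scale, zero_point, 0, 127))  # -> 5.0, 12.0, 12.7 (saturates at quant_max=127)
print(fake_quant(x, scale, zero_point, 0, 255))  # -> 5.0, 12.0, 15.0 (no saturation with quant_max=255)
# only the saturating element differs, so the mismatch ratio stays small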
I also printed the converted models produced by the 3 converters and compared them:
_convert_to_reference_decomposed_fx:
GraphModule(
  (conv_debug): QuantizedConv2d(Reference)(3, 5, kernel_size=(3, 3), stride=(1, 1))
)

def forward(self, x):
    conv_debug_input_scale_0 = self.conv_debug_input_scale_0
    conv_debug_input_zero_point_0 = self.conv_debug_input_zero_point_0
    quantize_per_tensor = torch.ops.quantized_decomposed.quantize_per_tensor(x, conv_debug_input_scale_0, conv_debug_input_zero_point_0, 0, 127, torch.uint8); x = None
    dequantize_per_tensor = torch.ops.quantized_decomposed.dequantize_per_tensor(quantize_per_tensor, conv_debug_input_scale_0, conv_debug_input_zero_point_0, 0, 127, torch.uint8); quantize_per_tensor = conv_debug_input_scale_0 = conv_debug_input_zero_point_0 = None
    conv_debug = self.conv_debug(dequantize_per_tensor); dequantize_per_tensor = None
    conv_debug_scale_0 = self.conv_debug_scale_0
    conv_debug_zero_point_0 = self.conv_debug_zero_point_0
    quantize_per_tensor_1 = torch.ops.quantized_decomposed.quantize_per_tensor(conv_debug, conv_debug_scale_0, conv_debug_zero_point_0, 0, 127, torch.uint8); conv_debug = None
    dequantize_per_tensor_1 = torch.ops.quantized_decomposed.dequantize_per_tensor(quantize_per_tensor_1, conv_debug_scale_0, conv_debug_zero_point_0, 0, 127, torch.uint8); quantize_per_tensor_1 = conv_debug_scale_0 = conv_debug_zero_point_0 = None
    return dequantize_per_tensor_1
convert_to_reference_fx:
GraphModule(
  (conv_debug): QuantizedConv2d(Reference)(3, 5, kernel_size=(3, 3), stride=(1, 1))
)

def forward(self, x):
    conv_debug_input_scale_0 = self.conv_debug_input_scale_0
    conv_debug_input_zero_point_0 = self.conv_debug_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, conv_debug_input_scale_0, conv_debug_input_zero_point_0, torch.quint8); x = conv_debug_input_scale_0 = conv_debug_input_zero_point_0 = None
    dequantize = quantize_per_tensor.dequantize(); quantize_per_tensor = None
    conv_debug = self.conv_debug(dequantize); dequantize = None
    conv_debug_scale_0 = self.conv_debug_scale_0
    conv_debug_zero_point_0 = self.conv_debug_zero_point_0
    quantize_per_tensor_1 = torch.quantize_per_tensor(conv_debug, conv_debug_scale_0, conv_debug_zero_point_0, torch.quint8); conv_debug = conv_debug_scale_0 = conv_debug_zero_point_0 = None
    dequantize_1 = quantize_per_tensor_1.dequantize(); quantize_per_tensor_1 = None
    return dequantize_1
convert_fx:
GraphModule(
  (conv_debug): QuantizedConv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), scale=0.02617141790688038, zero_point=65)
)

def forward(self, x):
    conv_debug_input_scale_0 = self.conv_debug_input_scale_0
    conv_debug_input_zero_point_0 = self.conv_debug_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, conv_debug_input_scale_0, conv_debug_input_zero_point_0, torch.quint8); x = conv_debug_input_scale_0 = conv_debug_input_zero_point_0 = None
    conv_debug = self.conv_debug(quantize_per_tensor); quantize_per_tensor = None
    dequantize_1 = conv_debug.dequantize(); conv_debug = None
    return dequantize_1
The quantize function in the models produced by convert_fx and convert_to_reference_fx is torch.quantize_per_tensor, whose signature is:
def quantize_per_tensor(input: Tensor, scale: Tensor, zero_point: Tensor, dtype: _dtype) -> Tensor: ...
But the quantize function in the model produced by _convert_to_reference_decomposed_fx is torch.ops.quantized_decomposed.quantize_per_tensor, whose signature is:
quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> Tensor
Compared with the former, this function takes the extra arguments quant_min and quant_max, and that is the essential difference.
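To make the difference concrete, here is a small standalone sketch comparing the two ops (scale and zero_point are made-up values, just for illustration; the explicit import of the _decomposed module is what registers the quantized_decomposed ops in torch 1.13):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- registers the quantized_decomposed ops

x = torch.tensor([[0.5, 1.4, 3.0]])
scale, zero_point = 0.02, 0

# eager/reference op: the clamp range is implied by the dtype (0..255 for quint8)
q_full = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
print(q_full.int_repr())  # -> 25, 70, 150 (clamped to [0, 255])

# decomposed op: quant_min/quant_max are explicit arguments, so the reduced range
# (0..127, as in the graph above) is visible in the exported graph itself
q_reduced = torch.ops.quantized_decomposed.quantize_per_tensor(x, scale, zero_point, 0, 127, torch.uint8)
print(q_reduced)  # -> 25, 70, 127 (clamped to [0, 127])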
Is it incorrect to use the fbgemm backend, following the tutorial in the Quantization — PyTorch 1.13 documentation, when we use convert_fx to convert the QAT model?