Dear PyTorch forum, I find that the result of a quantized conv2d is different from what I calculate by hand.
First, I import the necessary packages:
import torch
import torch.nn as nn
Then I build a simple convolution model:
class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, stride=1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.dequant(x)
        return x
net = CustomModel()
net.eval()
Configuration for quantizing the model:
my_qconfig = torch.quantization.qconfig.QConfig(
    activation=torch.quantization.observer.HistogramObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True),
    weight=torch.quantization.observer.PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False)
)
net.qconfig = my_qconfig
torch.backends.quantized.engine = "fbgemm"
torch.quantization.prepare(net, inplace=True)
Calibrate the model with randomly generated data:
calibrate_data = torch.randint(low=0, high=255, size=(1, 4, 16), dtype=torch.uint8).unsqueeze(0)
calibrate_data = calibrate_data / 255
_ = net(calibrate_data)
Quantize the model:
torch.quantization.convert(net, inplace=True)
Printing the quantized model gives:
CustomModel(
  (quant): Quantize(scale=tensor([0.0078]), zero_point=tensor([0]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), scale=0.004018464591354132, zero_point=0)
  (dequant): DeQuantize()
)
From this I can see that:
- Input Scale = 0.0078
- Input Zero Point = 0
- Output Scale of Quantized Conv2d = 0.0040 (exactly 0.004018464591354132)
- Output Zero Point of Quantized Conv2d = 0
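Note that the printed scales are rounded; the exact values can also be read programmatically (a minimal sketch, assuming the standard accessors on the converted modules):
in_scale = net.quant.scale.item()    # exact input scale (printed as 0.0078)
in_zp = net.quant.zero_point.item()  # input zero point, 0
out_scale = net.conv1.scale          # exact output scale, 0.004018464591354132
out_zp = net.conv1.zero_point        # output zero point, 0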
I create a sample input:
channel1 = torch.arange(0, 64).view(4, 16).to(torch.uint8).unsqueeze(0)
input_data = channel1.unsqueeze(0)
input_data = input_data / 255
Register a forward hook on every leaf module, because I want to save the activations of the model:
activations = []

def custom_hook(module, input, output):
    info = {
        'module': module,
        'input': input,
        'output': output
    }
    activations.append(info)

for name, module in net.named_modules():
    if len(list(module.children())) == 0:  # leaf modules only
        module.register_forward_hook(custom_hook)
Feed the input data to the model:
_ = net(input_data)
Extract the input and output of the quantized conv2d:
activations[1]['input'][0].int_repr()
activations[1]['output'].int_repr()
# input
tensor([[[[ 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8],
[ 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16],
[16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24],
[24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31, 32]]]],
dtype=torch.uint8)
# output
tensor([[[[ 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 13, 13, 14, 14, 15],
[15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23],
[23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 28, 28, 29, 29, 30],
[30, 31, 31, 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38]]]],
dtype=torch.uint8)
I find that when the quantized input value is 24, the corresponding output value is 30.
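For example, reading one matching position straight from the hooked tensors (index [0, 0, 3, 0], picked from the printouts above):
x_q = activations[1]['input'][0].int_repr()
y_q = activations[1]['output'].int_repr()
print(x_q[0, 0, 3, 0].item(), y_q[0, 0, 3, 0].item())  # 24 30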
I check the weight and bias of the quantized conv2d and their integer representations:
net.conv1.weight()
net.conv1.weight().int_repr()
net.conv1.bias()
# weight
tensor([[[[0.4808]]]], size=(1, 1, 1, 1), dtype=torch.qint8,
quantization_scheme=torch.per_channel_affine,
scale=tensor([0.0038], dtype=torch.float64), zero_point=tensor([0]),
axis=0)
# int repr
tensor([[[[127]]]], dtype=torch.int8)
# bias
tensor([0.0317], requires_grad=True)
From this I can see that:
- Weight Scale = 0.0038
- Weight Zero Point = 0
- Weight = 127
- Bias (before quantized) = 0.0317
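Again, the printed values are rounded; the exact per-channel weight scale and the float bias can be pulled out directly (a sketch using the standard quantized-tensor accessors):
w_scale = net.conv1.weight().q_per_channel_scales()[0].item()  # exact weight scale (printed as 0.0038)
w_int = net.conv1.weight().int_repr().item()                   # 127
bias_fp = net.conv1.bias().item()                              # 0.0317...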
After gathering all this information, I try to reproduce the output (30) from the input (24), but it fails:
(the calculation formula is from here, provided by @jerryzh168)
q_x = 24
z = q_x * 127  # z = q_x * w (both zero points are 0)
bias_q = round(0.0317 / (0.0038 * 0.0078))  # bias_q = round(bias / (weight_scale * input_scale))
z_int = z + bias_q
z_out = z_int * 0.0078 * 0.0038 / 0.0040 + 0  # z_out = z_int * input_scale * weight_scale / output_scale + output_zp
z_out  # 30.514379999999996, which rounds to 31, not the observed 30
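For reference, the same arithmetic can also be written against the exact stored parameters instead of the rounded repr values, reusing the variables from the sketches above:
q_x = 24
z = q_x * w_int                                 # both zero points are 0
bias_q = round(bias_fp / (w_scale * in_scale))  # quantized bias
z_int = z + bias_q
z_out = z_int * in_scale * w_scale / out_scale + out_zp
print(z_out)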
I have no idea where the mistake is and need your help. Thanks for any advice!