Sorry if this question has been answered before. Although I’ve found several similar topics here, I still cannot produce a fully-quantized model. Before diving into the code, let me define what “fully-quantized” means: every tensor in the model (inputs & outputs, weights, activations, and biases) is quantized to an integer type, and all computations are performed in integer arithmetic.
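To make that concrete, here is the kind of check I have in mind (a minimal sketch; report_dtypes is my own hypothetical helper, not from any of the topics above). It just lists the dtype of every parameter and buffer so any remaining float tensor stands out; note that scale/zero_point buffers legitimately stay float even in a quantized model.

import torch

def report_dtypes(model: torch.nn.Module) -> None:
    # Print the dtype of every parameter and buffer so leftover float
    # tensors stand out (scale/zero_point buffers are expected to be float).
    for name, t in list(model.named_parameters()) + list(model.named_buffers()):
        print(f"{name}: {t.dtype}")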
I’m using PyTorch v1.13.0 for the code snippet below, which builds on the topics “Input image with int?” and “Is bias quantized while doing pytorch static quantization?”.
import copy

import torch
from torch import nn
from torch.ao.quantization import QConfigMapping, get_default_qconfig, quantize_fx
from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig

# Define a simple model
model = nn.Sequential(
    nn.Conv2d(2, 64, 3),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3),
    nn.ReLU(),
)

m = copy.deepcopy(model)
m.eval()

# Configure for full quantization using FX graph mode.
# Running on an x86 CPU, so use the fbgemm backend.
backend = "fbgemm"
qconfig = get_default_qconfig(backend)
qconfig_mapping = QConfigMapping().set_global(qconfig)

# Declare the model's input and output tensors as quantized
prepare_custom_config = (
    PrepareCustomConfig()
    .set_input_quantized_indexes([0])   # quantize the input tensor
    .set_output_quantized_indexes([0])  # quantize the output tensor
)

# Ask for int32 biases via a custom backend config
weighted_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.qint32,
)
conv_config = (
    BackendPatternConfig(torch.nn.Conv2d)
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    .add_dtype_config(weighted_int8_dtype_config)
)
backend_config = BackendConfig("full_int_backend").set_backend_pattern_config(conv_config)

# Prepare the model
x = torch.randint(0, 255, size=[1, 2, 28, 28])  # note: dtype defaults to torch.int64
model_prepared = quantize_fx.prepare_fx(
    m,
    qconfig_mapping,
    example_inputs=(x,),
    prepare_custom_config=prepare_custom_config,
    backend_config=backend_config,
)

# Calibrate with representative (validation) data
with torch.inference_mode():
    for _ in range(10):
        x = torch.randint(0, 255, size=[1, 2, 28, 28])
        model_prepared(x)

# Quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
However, the bias of Conv2d still seems to be float, as the traceback below shows (see the RuntimeError at the bottom).
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/var/folders/cm/8dfdl1jx59s_qthfgrk14l040000gr/T/ipykernel_35569/1227764410.py in <module>
40 for _ in range(10):
41 x = torch.randint(0, 255, size=[1,2,28,28])
---> 42 model_prepared(x)
43 # quantize
44 model_quantized = quantize_fx.convert_fx(model_prepared)
~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/fx/graph_module.py in call_wrapped(self, *args, **kwargs)
656
657 def call_wrapped(self, *args, **kwargs):
--> 658 return self._wrapped_call(self, *args, **kwargs)
659
660 cls.__call__ = call_wrapped
~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/fx/graph_module.py in __call__(self, obj, *args, **kwargs)
275 raise e.with_traceback(None)
276 else:
--> 277 raise e
278
279 @compatibility(is_backward_compatible=True)
~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/fx/graph_module.py in __call__(self, obj, *args, **kwargs)
265 return self.cls_call(obj, *args, **kwargs)
266 else:
--> 267 return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
268 except Exception as e:
269 assert e.__traceback__
~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
<eval_with_key>.8 in forward(self, input)
4 def forward(self, input):
5 input_1 = input
----> 6 _0 = getattr(self, "0")(input_1); input_1 = None
7 activation_post_process_0 = self.activation_post_process_0(_0); _0 = None
8 _1 = getattr(self, "1")(activation_post_process_0); activation_post_process_0 = None
~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input)
461
462 def forward(self, input: Tensor) -> Tensor:
--> 463 return self._conv_forward(input, self.weight, self.bias)
464
465 class Conv3d(_ConvNd):
~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
458 _pair(0), self.dilation, self.groups)
459 return F.conv2d(input, weight, bias, self.stride,
--> 460 self.padding, self.dilation, self.groups)
461
462 def forward(self, input: Tensor) -> Tensor:
RuntimeError: Input type (long long) and bias type (float) should be the same
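For reference, this is how I would verify the dtypes if convert_fx ever succeeds (a minimal sketch; it assumes the converted convolutions end up as torch.nn.quantized.Conv2d modules, whose weight() and bias() methods unpack the packed parameters):

import torch.nn.quantized as nnq

# Inspect each quantized convolution in the converted model; with fbgemm I
# would expect weight() to be qint8, and the question is what bias() returns.
for name, mod in model_quantized.named_modules():
    if isinstance(mod, nnq.Conv2d):
        bias = mod.bias()
        print(name, mod.weight().dtype, bias.dtype if bias is not None else None)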