How to generate a fully-quantized model?

Sorry if this question has been answered before. Although I’ve found several similar topics here, I still cannot produce a fully-quantized model. Before diving into the code, let’s define what “fully-quantized” means: all tensors in the model (inputs and outputs, weights, activations, and biases) are quantized to integers, and all computations are performed in integer arithmetic.

I’m using PyTorch v1.13.0 for the code snippet below, which builds on the forum threads “Input image with int?” and “Is bias quantized while doing pytorch static quantization?”.

# Attempt at a "fully quantized" model with FX graph-mode post-training
# static quantization on PyTorch 1.13.0: quantized input/output, qint8
# weights, quint8 activations, and (intended) qint32 bias for Conv2d.
import torch
from torch import nn
import copy

# Define a simple model
model = nn.Sequential(
     nn.Conv2d(2,64,3),
     nn.ReLU(),
     nn.Conv2d(64, 128, 3),
     nn.ReLU()
)
# Configuring for full quantization using FX graph mode
# NOTE(review): quantize_fx comes from the legacy torch.quantization
# namespace while the other imports use torch.ao.quantization.
from torch.quantization import quantize_fx
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig

# Work on a copy in eval mode so the original float model stays untouched.
m = copy.deepcopy(model)
m.eval()

# Running on a x86 CPU
backend = "fbgemm"  
qconfig = torch.quantization.get_default_qconfig(backend)
# Apply the same qconfig to every module in the model.
qconfig_mapping = QConfigMapping().set_global(qconfig)
# Quantizing the inputs and outputs
prepare_custom_config = PrepareCustomConfig() 
# NOTE(review): per the PrepareCustomConfig docs these mark graph input 0 /
# output 0 as quantized tensors (the converted model then takes and returns
# quantized tensors rather than inserting quantize/dequantize ops) — confirm.
prepare_custom_config.set_input_quantized_indexes([0])  # to quantize the input tensor
prepare_custom_config.set_output_quantized_indexes([0]) # to quantize the output tensor

# Quantizing biases to int32
# NOTE(review): ObservationType is imported but not used below.
from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType

# NOTE(review): the reply at the end of this thread states that PyTorch's
# native quantized operators keep the bias in fp32, so bias_dtype here does
# not give the converted Conv2d an int32 bias.
weighted_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.qint32)

# Custom backend with a single pattern: only nn.Conv2d gets a dtype config.
# NOTE(review): no pattern is registered for nn.ReLU — confirm intended.
conv_config = BackendPatternConfig(torch.nn.Conv2d).add_dtype_config(weighted_int8_dtype_config) 
backend_config = BackendConfig("full_int_backend").set_backend_pattern_config(conv_config)
# prepare and convert the model
# Example input (int64 values in [0, 255) from torch.randint).
# NOTE(review): prepare_fx documents example_inputs as a tuple of positional
# args, i.e. (x,) rather than x — confirm against the 1.13 API.
x = torch.randint(0, 255, size=[1,2,28,28])

model_prepared = quantize_fx.prepare_fx(m, qconfig_mapping, x, prepare_custom_config, backend_config=backend_config)

# Calibrate - Use representative (validation) data.
# NOTE(review): the prepared model still runs the fp32 Conv2d modules, so
# these int64 torch.randint tensors make the first conv raise
# "RuntimeError: Input type (long long) and bias type (float) should be the
# same" on model_prepared(x). Calibration data should be float tensors.
with torch.inference_mode():
  for _ in range(10):
    x = torch.randint(0, 255, size=[1,2,28,28])
    model_prepared(x)
# Quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

However, the Conv2d bias still appears to be float, as the log below shows (see its last line).

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/var/folders/cm/8dfdl1jx59s_qthfgrk14l040000gr/T/ipykernel_35569/1227764410.py in <module>
     40   for _ in range(10):
     41     x = torch.randint(0, 255, size=[1,2,28,28])
---> 42     model_prepared(x)
     43 # quantize
     44 model_quantized = quantize_fx.convert_fx(model_prepared)

~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/fx/graph_module.py in call_wrapped(self, *args, **kwargs)
    656 
    657         def call_wrapped(self, *args, **kwargs):
--> 658             return self._wrapped_call(self, *args, **kwargs)
    659 
    660         cls.__call__ = call_wrapped

~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/fx/graph_module.py in __call__(self, obj, *args, **kwargs)
    275                 raise e.with_traceback(None)
    276             else:
--> 277                 raise e
    278 
    279 @compatibility(is_backward_compatible=True)

~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/fx/graph_module.py in __call__(self, obj, *args, **kwargs)
    265                 return self.cls_call(obj, *args, **kwargs)
    266             else:
--> 267                 return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
    268         except Exception as e:
    269             assert e.__traceback__

~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

<eval_with_key>.8 in forward(self, input)
      4 def forward(self, input):
      5     input_1 = input
----> 6     _0 = getattr(self, "0")(input_1);  input_1 = None
      7     activation_post_process_0 = self.activation_post_process_0(_0);  _0 = None
      8     _1 = getattr(self, "1")(activation_post_process_0);  activation_post_process_0 = None

~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input)
    461 
    462     def forward(self, input: Tensor) -> Tensor:
--> 463         return self._conv_forward(input, self.weight, self.bias)
    464 
    465 class Conv3d(_ConvNd):

~/opt/anaconda3/envs/tvm-build-11/lib/python3.7/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    458                             _pair(0), self.dilation, self.groups)
    459         return F.conv2d(input, weight, bias, self.stride,
--> 460                         self.padding, self.dilation, self.groups)
    461 
    462     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Input type (long long) and bias type (float) should be the same

@Vasiliy_Kuznetsov @jerryzh168 It’d be great if you could provide some advice. TIA.

Hi @kgbounce, I believe this is expected behavior: prepare_fx attaches quant-dequant pairs but still uses the fp32 modules / ops, so you need to calibrate with fp32 data.

Once the model has been converted, I think you can start feeding it int inputs.

@jcaip thanks! I’ve previously followed the default settings in the PyTorch quantization tutorials, with which convert_fx() runs without any issues. However, the converted and quantized model still has fp32 biases. Could you comment on the code in the original post that attempts to set the dtype of the biases to `torch.qint32` instead of the default `torch.float32`? It doesn’t seem to work, as the error message says the bias is still fp32.

# Same configuration as in the original post. The line being asked about is
# bias_dtype: the intent is a qint32 bias instead of the default fp32 one.
weighted_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.qint32)  # HERE: trying to set bias to qint32, rather than fp32.
# NOTE: PyTorch's native quantized operators keep the bias in fp32, so this
# bias_dtype does not make the converted Conv2d store an int32 bias. An
# integer bias requires producing a reference quantized model and lowering
# it to a backend that actually supports integer bias.

# Custom backend with a single pattern: only nn.Conv2d gets a dtype config.
conv_config = BackendPatternConfig(torch.nn.Conv2d).add_dtype_config(weighted_int8_dtype_config)
backend_config = BackendConfig("full_int_backend").set_backend_pattern_config(conv_config)

# prepare and convert the model
x = torch.randint(0, 255, size=[1,2,28,28])

# prepare_fx takes example_inputs as a tuple of positional arguments for
# forward(), hence (x,) rather than the bare tensor.
model_prepared = quantize_fx.prepare_fx(m, qconfig_mapping, (x,), prepare_custom_config, backend_config=backend_config)

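Later, the following error is reported. The traceback in the original post shows that it is raised by the calibration call `model_prepared(x)`, before convert_fx() runs: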

RuntimeError: Input type (long long) and bias type (float) should be the same

@jcaip I tried using fp32 data during calibration with the code below. Note that example_input_int is now generated with torch.randint(), while the calibration input x is generated with torch.rand(). This code yields the following error:

NotImplementedError: Could not run 'quantized::conv2d_relu.new' with arguments from the 'CPU' backend.

Although it is listed as a common error in the Quantization page of the PyTorch master documentation, it is not clear how to fix it.

## FX GRAPH
# Second attempt: default "x86" qconfig with the built-in backend config
# (no custom BackendConfig), float calibration data, then inference with an
# integer-valued tensor. Uses copy, torch and model from the first snippet.
from torch.quantization import quantize_fx
# NOTE(review): QConfigMapping is imported but this snippet passes a plain
# dict (qconfig_dict_int) to prepare_fx instead.
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig

# Work on a copy in eval mode so the original float model stays untouched.
m_int = copy.deepcopy(model)
m_int.eval()

backend_int = "x86"  # running on a x86 CPU. Use "qnnpack" if running on ARM.
# Dict-style qconfig; the "" key is the global entry (same intent as
# QConfigMapping().set_global(...) in the first snippet).
qconfig_dict_int = {"": torch.quantization.get_default_qconfig(backend_int)}
prepare_custom_config = PrepareCustomConfig() 
# NOTE(review): marking input 0 as quantized means the converted model is
# given an already-quantized tensor at its input (no quantize op is inserted
# for it). This is consistent with the reported
# "NotImplementedError: Could not run 'quantized::conv2d_relu.new' with
# arguments from the 'CPU' backend" when a plain int64 tensor is passed at
# the end of this snippet — confirm.
prepare_custom_config.set_input_quantized_indexes([0])  
prepare_custom_config.set_output_quantized_indexes([0])
# Prepare
# int64 example input; also reused for the final inference call below.
# NOTE(review): prepare_fx documents example_inputs as a tuple of positional
# args, i.e. (example_input_int,) — confirm against the 1.13 API.
example_input_int = torch.randint(255, [1,2,28,28])
# example_input_int = torch.rand(1,2,28,28)
model_prepared_int = quantize_fx.prepare_fx(
    m_int, qconfig_dict_int,example_input_int, prepare_custom_config)

# Calibrate - Use representative (validation) data.
# Float data here: the prepared model still runs the fp32 modules.
with torch.inference_mode():
  for _ in range(10):
    # x = torch.randint(255, [1,2,28,28])
    x = torch.rand(1,2,28,28)
    model_prepared_int(x)
# quantize
model_quantized_int = quantize_fx.convert_fx(model_prepared_int)

# NOTE(review): example_input_int is a plain (non-quantized) int64 tensor.
# Because input 0 was marked as quantized, a quantized tensor is presumably
# expected here instead, e.g.
# torch.quantize_per_tensor(float_input, scale, zero_point, torch.quint8) — confirm.
res_int = model_quantized_int(example_input_int)

print(res_int)

I tried using fp32 data during calibration with the code below. Note that now example_input_int is generated with torch.randint() and x used for calibration is generated with torch.rand()

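This is because the converted quantized model still expects floating-point data as input; the model would have a quantize operator to convert the input from fp32 to int8. (That holds when the input is not marked as quantized. With set_input_quantized_indexes([0]), no quantize operator is inserted for input 0, so the converted model expects an already-quantized tensor there. A plain int64 tensor therefore reaches quantized::conv2d_relu directly.)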

However, the converted and quantized model still has fp32 biases. Could you comment on the code in the original post that attempts to set the dtype of the biases to `torch.qint32` instead of the default `torch.float32`? It doesn’t seem to work, as the error message says the bias is still fp32.

The default setting for biases in the quantized operators supported by PyTorch is fp32, I believe. If you wish to use custom dtypes, you can follow the steps in https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md#use-case-1-quantizing-a-model-for-inference-on-servermobile to generate a reference quantized model and then lower it to your specific backend that supports integer bias.