I would like to execute a PyTorch model trained with quantization-aware training (QAT) as a fully quantized model. However, the outputs of my fully quantized and fake quantized models do not match.
To highlight the problem, I defined a very simple experiment consisting of quantizing only a single fused Conv-ReLU operation with hard-coded weights and quantization parameters. What I noticed is that:
- Torch produces the expected result for a model prepared with `torch.quantization.prepare_qat` (fake quantized)
- Torch produces unexpected results when the previously prepared model is converted to a fully quantized model with `torch.quantization.convert` (real quantized)
I am aware that I might have an error in my implementation, so I provide a detailed example below and a repo of the example here.
Experiment
Compare fake and real quantized model outputs:
- inferring with a normal QAT model (fake quantized) produces the expected results
- inferring with a model prepared and converted to int8 (real quantized) produces unexpected results
To highlight the issue, I have set up a simple toy example as follows:
Model
A simple Conv-ReLU fused model is defined with
- bias set to zero
- conv weights set to `k*I`, where `k` is some floating point scalar multiplier and `I` represents an identity matrix of the correct shape for the conv layer weights
- A quantization stub which quantizes the `fp32` inputs to `quint8`
- A dequantization stub which dequantizes the `quint8` outputs to `fp32`
  - note, this stub gets set to the identity for the fully int8 quantized model
```python
import torch
import torch.nn as nn


class FooConv1x1(nn.Module):
    def __init__(self, set_qconfig):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1, 1)  # 1x1 Conv kernel
        self.act = nn.ReLU()
        self.quant = torch.quantization.QuantStub(CustomQConfigs.get_default_qconfig())
        self.dequant = torch.quantization.DeQuantStub()
        self.modules_to_fuse = [['conv', 'act']]
        if set_qconfig:
            self.set_qconfig()

    def forward(self, x):
        x = self.quant(x)
        output_quant = self.act(self.conv(x))
        return self.dequant(output_quant)

    def fuse(self):
        torch.ao.quantization.fuse_modules(self, self.modules_to_fuse, inplace=True)
        return self

    def set_qconfig(self):
        self.qconfig = CustomQConfigs.get_default_qconfig()
        return self

    def set_weights(self, multiplier):
        # Set bias to zero and conv weights to k*Identity
        self.conv.bias = torch.nn.Parameter(torch.zeros_like(self.conv.bias))
        self.conv.weight = torch.nn.Parameter(multiplier * torch.eye(3).reshape(self.conv.weight.shape))
```
PyTorch QConfig
The quantization config for the model was defined with:
- Per tensor affine quantization everywhere, except for the Conv layer’s weights, which are per tensor symmetric
```python
class CustomQConfigs:
    @staticmethod
    def get_default_qconfig():
        return torch.quantization.QConfig(
            activation=torch.quantization.FusedMovingAvgObsFakeQuantize.with_args(
                observer=torch.quantization.MovingAverageMinMaxObserver,
                quant_min=0,
                quant_max=255,
                reduce_range=False),
            weight=torch.quantization.FusedMovingAvgObsFakeQuantize.with_args(
                observer=torch.quantization.MovingAverageMinMaxObserver,
                quant_min=-128,
                quant_max=127,
                dtype=torch.qint8,
                qscheme=torch.per_tensor_symmetric))
```
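For reference, here is a minimal sketch (not part of the repo) of how the quantization parameters hard coded later in `set_qconfig_params` follow from this QConfig when the observed range is `[0, k]`:

```python
# Minimal sketch: analytically derived scales for an observed range of [0, k],
# matching the values hard coded in set_qconfig_params below.
k = 1.0

# Affine quint8 activations over [0, k] with quant_min=0, quant_max=255:
# scale = (max_val - min_val) / (quant_max - quant_min)
activation_scale = (k - 0.0) / (255 - 0)  # = k / 255, zero point = 0

# Symmetric qint8 weights over [0, k]: the range is symmetrized to [-k, k],
# so scale = k / 127.5 = 2 * k / 255 (the "multiply by 2" in the code below)
weight_scale = 2.0 * k / 255.0

print(activation_scale, weight_scale)
```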
Inputs
Inputs are provided to the model in single precision floating point units in all cases. To highlight the issue, I consider passing a range of input values between 0 and 255 across an input image of size [1,3,256,256]
and scaling the values to between 0 and 1 by dividing by 255.
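For concreteness, this is a minimal, equivalent construction of that input (the same line appears in the script in the next section):

```python
import torch

# Each row of every channel holds the values 0..255, normalized to [0, 1]
input_fp32 = torch.arange(0, 256).repeat(1, 3, 256, 1) / 255.0
assert input_fp32.shape == (1, 3, 256, 256)
assert input_fp32.min() == 0.0 and input_fp32.max() == 1.0
```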
Setup and execution
The example can be executed using the following snippets or by cloning the repo and following the instructions here.
Note, I use `torch.quantization.prepare_qat` instead of `torch.quantization.prepare` so that fake-quantize observers also get attached to the Conv layer's weights, which lets me hard code the quantization parameters to their analytically calculated values.
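As a quick illustration of the difference (not in the repo, and assuming `model` is a fused `FooConv1x1` instance with its qconfig set, as in the script below), only the QAT-prepared conv carries a `weight_fake_quant` module whose parameters can be overridden:

```python
import copy
import torch

m_qat = torch.quantization.prepare_qat(copy.deepcopy(model).train())
m_ptq = torch.quantization.prepare(copy.deepcopy(model).eval())

print(hasattr(m_qat.conv, "weight_fake_quant"))  # True: weight qparams can be hard coded
print(hasattr(m_ptq.conv, "weight_fake_quant"))  # False: weights are only quantized at convert time
```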
```python
import torch

# FooConv1x1 and CustomQConfigs are defined above (see the repo)

backend = "fbgemm"
# backend = "qnnpack"
torch.backends.quantized.engine = backend
torch.manual_seed(0)


# Hard code relevant quantization parameters
def set_qconfig_params(model_prepared, k):
    # Conv weight
    model_prepared.conv.weight_fake_quant.scale = torch.Tensor([2.0*k/255.0])  # Symmetric, hence multiply by 2
    model_prepared.conv.weight_fake_quant.activation_post_process.min_val = torch.tensor(0.0)
    model_prepared.conv.weight_fake_quant.activation_post_process.max_val = torch.tensor(k)
    # Requantization
    model_prepared.conv.activation_post_process.scale = torch.Tensor([k/255.0])
    model_prepared.conv.activation_post_process.min_val = torch.tensor(0.0)
    model_prepared.conv.activation_post_process.max_val = torch.tensor(k)
    model_prepared.conv.activation_post_process.activation_post_process.min_val = torch.tensor(0.0)
    model_prepared.conv.activation_post_process.activation_post_process.max_val = torch.tensor(k)
    # Input quant stub
    model_prepared.quant.activation_post_process.scale = torch.Tensor([1.0/255.0])
    model_prepared.quant.activation_post_process.activation_post_process.min_val = torch.tensor(0.0)
    model_prepared.quant.activation_post_process.activation_post_process.max_val = torch.tensor(1.0)


if __name__ == "__main__":
    input_fp32 = torch.arange(0, 256).repeat(1, 3, 256, 1)/255.0  # 0 to 255 repeated across rows, then normalized to [0,1]

    model = FooConv1x1(set_qconfig=True)  # Prepare model with QConfig defined
    k = 1.0  # Set Conv layer multiplier
    model.set_weights(k)  # Set bias to zero and conv weights to k*Identity
    model.fuse()  # Fuse Conv and ReLU

    model_prepared = torch.quantization.prepare_qat(model).train()  # prepare_qat required to set weight qparams
    model_prepared.eval()
    model_prepared.apply(torch.quantization.disable_observer).eval()  # Disable QConfig Observers
    set_qconfig_params(model_prepared, k)  # Set quantization parameters to theoretical values

    expected_output_fp32 = model_prepared(input_fp32)
    expected_output_quint8 = (expected_output_fp32*(k*255)).to(torch.uint8)

    model_prepared.dequant = torch.nn.Identity()  # Disable the output dequant stub

    # Convert model so that it runs as a fully quantized model
    model_quant = torch.quantization.convert(model_prepared, inplace=False)
    output_quint8_fp32 = model_quant(input_fp32)  # quantized (quint8) output tensor with its scale and zero point, since dequant is now Identity

    error = torch.abs(expected_output_fp32 - output_quint8_fp32.dequantize())
    error_mean = torch.mean(error)
    error_max = torch.max(error)
    first_nonzero_index = error.nonzero()[0].tolist()
    print(f"{error_mean=}")
    print(f"{error_max=}")
    print(f"First nonzero index: {first_nonzero_index}")
    print(f"\tvalue fp32: {error[*first_nonzero_index]}")
    print(f"\tvalue expected quint8: {expected_output_quint8[*first_nonzero_index]}")
    print(f"\tvalue outputted quint8: {output_quint8_fp32.int_repr()[*first_nonzero_index]}")
    # import ipdb; ipdb.set_trace()
```
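As an additional debugging step (not part of the repo), the parameters that the converted model actually uses can be inspected directly after running the script above:

```python
# Inspect the converted (real quantized) modules
print(type(model_quant.conv))                                 # quantized ConvReLU2d
print(model_quant.conv.scale, model_quant.conv.zero_point)    # output requantization qparams
w_q = model_quant.conv.weight()                               # quantized (qint8) weight
print(w_q.q_scale(), w_q.int_repr().flatten())                # weight scale and integer values
print(model_quant.quant.scale, model_quant.quant.zero_point)  # input Quantize module qparams
```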
Dependencies
The example in the repo was tested using:
- Python 3.11.0
- Python packages installed with pip, which are listed in the `requirements.txt` in the repo provided
Operating system:
- Ubuntu 22.04.1 LTS
Hardware architecture:
- x86_64 (Intel)
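For completeness, the relevant environment details can be printed with a snippet like this (not in the repo):

```python
import platform
import torch

print(platform.python_version())        # Python version (3.11.0 here)
print(torch.__version__)                # installed torch version, see requirements.txt
print(torch.backends.quantized.engine)  # currently selected quantized backend
```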
Observations
For simplicity, I compare the quantized outputs, but the same can be observed for the dequantized outputs.
- For the output of a model (fake or real quantized), I expect every row to be identical across all rows and channels. This was observed in all cases, indicating determinism within each model's execution
- The quantized outputs of the (fake quantized) model prepared with `torch.quantization.prepare_qat` were as expected:
  - values ranging from 0 to 255, indicating a unique bin for each of the outputs, which get dequantized into the expected output values
  - a summary of the first row of the first channel shows the beginning, middle and end of that row (the snippet after this list shows how these row summaries can be printed):
    `tensor([  0,   1, ..., 126, 127, 128, 129, 130, 131, ..., 253, 254, 255], dtype=torch.uint8)`
- The quantized outputs of the (real quantized) model converted with `torch.quantization.convert` were not quite as expected:
  - values ranging from 0 to 254, implying we are losing information somewhere within the quantized execution
  - a summary of the first row of the first channel shows the beginning, middle and end of that row:
    `tensor([  0,   1, ..., 126, 127, 127, 128, 129, 130, ..., 252, 253, 254], dtype=torch.uint8)`
- Comparing the quantized outputs of the two models, we observe in the real quantized model:
  - a discrepancy: the value 127 appears twice, and every value after it is shifted by one
  - this means the duplicated 127 lands in the incorrect bin for its expected dequantized value, and all values following the duplication are shifted into incorrect bins as well
  - this behaviour is unexpected and results in non-determinism across the two models' executions
  - note, it is interesting that the discrepancy first appears at 128 (i.e. ⌊255/2⌋ + 1, the halfway bin of the possible range of bins) and affects all following values
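The first row of the first channel, from which the row summaries above were taken, can be printed with something like:

```python
# First row of the first channel of each quantized output (summarized above)
print(expected_output_quint8[0, 0, 0, :])         # fake quantized (prepare_qat) model
print(output_quint8_fp32.int_repr()[0, 0, 0, :])  # real quantized (converted) model
```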
Conclusion
One of the main reasons for using quantization is to ensure determinism across different compute platforms, so this mismatch between the fake and the real quantized model is extremely problematic, especially when it comes to deploying quantized models.
It is clear from this example that the real quantized model is not working as expected. This must either be due to an error I have in my implementation or a bug within PyTorch. Any help determining the issue would be really appreciated, and if you think it is a bug in PyTorch, I will happily create a bug report, but I wanted to check here first.