I would like to execute a PyTorch model trained with quantization-aware training (QAT) as a fully quantized model. However, the outputs of my fully quantized and fake quantized models do not match.

To highlight the problem, I defined a very simple experiment that quantizes only a single fused Conv-ReLU operation with hard-coded weights and quantization parameters. What I noticed is that:

- Torch produces the expected result for a model prepared with `torch.quantization.prepare_qat` (fake quantized)
- Torch produces unexpected results when the previously prepared model is converted to a fully quantized model with `torch.quantization.convert` (real quantized)

I am wary that I might have an error in my implementation, so I provide a detailed example below and a repo of the example here.

## Experiment

Compare fake and real quantized model outputs:

- inferring with a normal QAT model (fake quantized): produces expected results
- inferring with a model prepared and then converted to int8 (real quantized): produces unexpected results

To highlight the issue, I have set up a simple toy example as follows:

### Model

A simple Conv-ReLU fused model is defined with

- bias set to zero
- conv weights set to `k*I`, where `k` is some floating point scalar multiplier and `I` represents an identity matrix of the correct shape for the conv layer weights
- A quantization stub which quantizes the `fp32` inputs to `quint8`
- A dequantization stub which dequantizes the `quint8` outputs to `fp32`
  - note, this stub gets set to the identity for the fully int8 quantized model

```
import torch
import torch.nn as nn


class FooConv1x1(nn.Module):
    def __init__(self, set_qconfig):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1, 1)  # 1x1 Conv kernel
        self.act = nn.ReLU()
        # CustomQConfigs is defined in the "PyTorch QConfig" section below
        self.quant = torch.quantization.QuantStub(CustomQConfigs.get_default_qconfig())
        self.dequant = torch.quantization.DeQuantStub()
        self.modules_to_fuse = [['conv', 'act']]
        if set_qconfig:
            self.set_qconfig()

    def forward(self, x):
        x = self.quant(x)
        output_quant = self.act(self.conv(x))
        return self.dequant(output_quant)

    def fuse(self):
        # Fuse conv + ReLU in place so they are quantized as a single op
        torch.ao.quantization.fuse_modules(self, self.modules_to_fuse, inplace=True)
        return self

    def set_qconfig(self):
        self.qconfig = CustomQConfigs.get_default_qconfig()
        return self

    def set_weights(self, multiplier):
        # Set bias to zero and conv weights to k*Identity
        self.conv.bias = torch.nn.Parameter(torch.zeros_like(self.conv.bias))
        self.conv.weight = torch.nn.Parameter(multiplier * torch.eye(3).reshape(self.conv.weight.shape))
```
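To make the intent of the `k*I` weights concrete, here is a minimal plain-Python sketch (no PyTorch) of what a bias-free 1x1 convolution followed by ReLU does to a single pixel. The helper name `conv1x1_relu` and the value `k = 0.5` are illustrative assumptions, not part of the repo; the point is only that in floating point the layer reduces to `relu(k * x)` per channel.

```python
def conv1x1_relu(pixel, weight):
    """Apply a bias-free 1x1 convolution followed by ReLU to one pixel.

    pixel  : list of C_in channel values at a single spatial location
    weight : C_out x C_in matrix (the 1x1 kernel with its spatial dims dropped)
    """
    out = []
    for row in weight:
        acc = sum(w * x for w, x in zip(row, pixel))
        out.append(max(acc, 0.0))  # ReLU
    return out


k = 0.5  # hypothetical multiplier, chosen only for illustration
identity = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
weight = [[k * v for v in row] for row in identity]  # k * I

pixel = [0.25, 0.5, 1.0]
print(conv1x1_relu(pixel, weight))  # each channel is simply scaled by k
```

With these weights every channel passes through independently, so any mismatch between the fake and real quantized outputs has to come from the quantization steps rather than from channel mixing.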

### PyTorch QConfig

The quantization config for the model was defined with:

- Per tensor affine quantization everywhere except for the Conv layer’s weights, which are per tensor symmetric

```
import torch


class CustomQConfigs:
    @staticmethod
    def get_default_qconfig():
        return torch.quantization.QConfig(
            # Activations: per tensor affine, unsigned 8-bit range
            activation=torch.quantization.FusedMovingAvgObsFakeQuantize.with_args(
                observer=torch.quantization.MovingAverageMinMaxObserver,
                quant_min=0,
                quant_max=255,
                reduce_range=False),
            # Weights: per tensor symmetric, signed 8-bit range
            weight=torch.quantization.FusedMovingAvgObsFakeQuantize.with_args(
                observer=torch.quantization.MovingAverageMinMaxObserver,
                quant_min=-128,
                quant_max=127,
                dtype=torch.qint8,
                qscheme=torch.per_tensor_symmetric))
```
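For reference, the difference between the two schemes can be sketched in plain Python. This is my understanding of how min/max-style observers turn an observed range into a scale and zero point; the function names `affine_qparams` and `symmetric_qparams` are hypothetical, and the formulas are an assumption to check against the PyTorch observer source rather than a statement of its implementation.

```python
import math


def affine_qparams(min_val, max_val, qmin, qmax):
    """Scale and zero point for per tensor affine quantization (assumed formula)."""
    # The representable range is extended so that it always contains zero
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - round(min_val / scale)
    return scale, max(qmin, min(qmax, zero_point))


def symmetric_qparams(min_val, max_val, qmin, qmax):
    """Scale and zero point for signed per tensor symmetric quantization (assumed formula)."""
    max_abs = max(abs(min_val), abs(max_val))
    scale = max_abs / ((qmax - qmin) / 2)
    return scale, 0


k = 1.0
act_scale, act_zp = affine_qparams(0.0, k, 0, 255)      # activations in [0, k]
w_scale, w_zp = symmetric_qparams(0.0, k, -128, 127)    # weights in [0, k]

print(act_scale, act_zp)
print(w_scale, w_zp)
print(math.isclose(w_scale, 2.0 * act_scale))  # symmetric scale is twice the affine one
```

Under these assumptions the symmetric weight scale comes out as `2*k/255` and the affine activation scale as `k/255`, which are the analytically calculated values I hard code further down.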

### Inputs

Inputs are provided to the model as single-precision floating point values in all cases. To highlight the issue, I pass a range of input values between 0 and 255 across an input image of size `[1,3,256,256]`, scaled to between 0 and 1 by dividing by 255.
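As a quick illustration of why this input is convenient, the sketch below quantizes the same 0 to 255 ramp in plain Python, assuming the input quant stub uses a scale of `1/255` and a zero point of `0` (the values I hard code in the setup script). The `quantize` helper is a hypothetical stand-in for affine quantization, not a PyTorch function.

```python
def quantize(x, scale, zero_point, qmin, qmax):
    """Affine quantization of a single float: round, shift, then clamp."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


scale, zero_point = 1.0 / 255.0, 0  # assumed input quant stub parameters
qmin, qmax = 0, 255                 # quint8 range

ramp = [i / 255.0 for i in range(256)]  # one row of the input image
bins = [quantize(x, scale, zero_point, qmin, qmax) for x in ramp]

print(bins[:3], bins[-3:])
print(len(set(bins)))  # number of distinct bins used by the ramp
```

Each input value lands in its own bin, so the quantized input row is exactly `0, 1, ..., 255` and every bin of the `quint8` range is exercised once per row.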

### Setup and execution

The example can be executed using the following snippets or by cloning the repo and following the instructions here.

Note, I use `torch.quantization.prepare_qat` instead of `torch.quantization.prepare` so that observers get added to the Conv layer, which lets me hard code the quantization parameters to their analytically calculated values.

```
import torch

# FooConv1x1 and CustomQConfigs are the classes defined in the snippets above

backend = "fbgemm"
# backend = "qnnpack"
torch.backends.quantized.engine = backend
torch.manual_seed(0)


# Hard code relevant quantization parameters
def set_qconfig_params(model_prepared, k):
    # Conv weight
    model_prepared.conv.weight_fake_quant.scale = torch.Tensor([2.0*k/255.0])  # Symmetric, hence multiply by 2
    model_prepared.conv.weight_fake_quant.activation_post_process.min_val = torch.tensor(0.0)
    model_prepared.conv.weight_fake_quant.activation_post_process.max_val = torch.tensor(k)
    # Requantization
    model_prepared.conv.activation_post_process.scale = torch.Tensor([k/255.0])
    model_prepared.conv.activation_post_process.min_val = torch.tensor(0.0)
    model_prepared.conv.activation_post_process.max_val = torch.tensor(k)
    model_prepared.conv.activation_post_process.activation_post_process.min_val = torch.tensor(0.0)
    model_prepared.conv.activation_post_process.activation_post_process.max_val = torch.tensor(k)
    # Input quant stub
    model_prepared.quant.activation_post_process.scale = torch.Tensor([1.0/255.0])
    model_prepared.quant.activation_post_process.activation_post_process.min_val = torch.tensor(0.0)
    model_prepared.quant.activation_post_process.activation_post_process.max_val = torch.tensor(1.0)


if __name__ == "__main__":
    input_fp32 = torch.arange(0, 256).repeat(1, 3, 256, 1)/255.0  # 0 to 255 repeated across rows, then normalized to [0,1]
    model = FooConv1x1(set_qconfig=True)  # Prepare model with QConfig defined
    k = 1.0  # Set Conv layer multiplier
    model.set_weights(k)  # Set bias to zero and conv weights to k*Identity
    model.fuse()  # fuse conv and ReLU
    model_prepared = torch.quantization.prepare_qat(model).train()  # prepare_qat required to set weight qparams
    model_prepared.eval()
    model_prepared.apply(torch.quantization.disable_observer).eval()  # Disable QConfig Observers
    set_qconfig_params(model_prepared, k)  # Set quantization parameters to theoretical values
    expected_output_fp32 = model_prepared(input_fp32)
    expected_output_quint8 = (expected_output_fp32*(k*255)).to(torch.uint8)
    model_prepared.dequant = torch.nn.Identity()  # Disable the output dequant stub
    # Convert model so that it runs as fully quantized model
    model_quant = torch.quantization.convert(model_prepared, inplace=False)
    output_quint8_fp32 = model_quant(input_fp32)  # fp32 outputs with scale and shift parameterising it to quint8
    error = torch.abs(expected_output_fp32 - output_quint8_fp32.dequantize())
    error_mean = torch.mean(error)
    error_max = torch.max(error)
    # Index of the first mismatching element, as a tuple so it can index a tensor
    first_nonzero_index = tuple(error.nonzero()[0].tolist())
    print(f"{error_mean=}")
    print(f"{error_max=}")
    print(f"First nonzero: index: ({first_nonzero_index})")
    print(f"\tvalue fp32: {error[first_nonzero_index]}")
    print(f"\tvalue expected quint8: {expected_output_quint8[first_nonzero_index]}")
    print(f"\tvalue outputted quint8: {output_quint8_fp32.int_repr()[first_nonzero_index]}")
    # import ipdb; ipdb.set_trace()
```

## Dependencies

The example in this repo was tested using

- Python 3.11.0
- Python packages installed with pip, which are listed in the `requirements.txt` in the repo provided

Operating system:

- Ubuntu 22.04.1 LTS

Hardware architecture:

- x86_64 - Intel

## Observations

For simplicity, I compare the quantized outputs, but the same can be observed for the dequantized outputs.

- For the output of a model (fake or real quantized), I expect every row to be identical across all rows and channels. This was observed in all cases, indicating determinism within a model’s execution
- The quantized outputs of the (fake quantized) model prepared with `torch.quantization.prepare_qat` were as expected:
  - values ranging from 0 to 255, indicating a unique bin for each of the outputs which get dequantized into the expected output value
  - a summary of the first row of the first channel depicts the beginning, middle and end of that row:

```
tensor([ 0, 1, ..., 126, 127, 128, 129, 130, 131, ..., 253, 254, 255], dtype=torch.uint8)
```

- The quantized outputs of the (real quantized) model converted with `torch.quantization.convert` were not quite as expected:
  - values ranging from 0 to 254, implying we are losing information somewhere within the quantized execution
  - a summary of the first row of the first channel depicts the beginning, middle and end of that row:

```
tensor([ 0, 1, ..., 126, 127, 127, 128, 129, 130, ..., 252, 253, 254], dtype=torch.uint8)
```

- Comparing the quantized outputs of the two models, we observe in the real quantized model:
  - a discrepancy can be seen as the value 127 appears twice, and all values after it are shifted down by one
  - this results in the repeated 127 value being the incorrect bin for its expected dequantized value, and all other values following this duplication have been shifted to incorrect bins as well
  - this behaviour is unexpected and means the two models’ executions do not agree with each other
  - note, it is interesting that this discrepancy appears for `128 = floor(255/2) + 1` and all following values, since 128 is the halfway bin of the possible range of bins
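The comparison above can be reproduced on the printed summaries with a few lines of plain Python. The `observed` row below is reconstructed from the truncated tensor printout, assuming the elided entries continue the same pattern; `first_mismatch` and `duplicated_values` are hypothetical helper names used only for this sketch.

```python
def first_mismatch(expected, observed):
    """Return the first index where the two rows differ, or None if they match."""
    for i, (e, o) in enumerate(zip(expected, observed)):
        if e != o:
            return i
    return None


def duplicated_values(row):
    """Return the values that occur more than once, in order of first repeat."""
    seen, dups = set(), []
    for v in row:
        if v in seen and v not in dups:
            dups.append(v)
        seen.add(v)
    return dups


# Fake quantized row: one bin per input value
expected = list(range(256))
# Real quantized row, reconstructed from the summary (assumption): 0..127, then 127..254
observed = list(range(128)) + list(range(127, 255))

idx = first_mismatch(expected, observed)
print(idx, expected[idx], observed[idx])
print(duplicated_values(observed))
```

On the actual tensors the same check can be done with the `error.nonzero()` call in the setup script; this sketch only makes the "duplicate at 127, shift from index 128 onwards" pattern explicit.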

## Conclusion

One of the main reasons for using quantization is to ensure determinism across different compute platforms, so the non-deterministic behaviour between a fake and real quantized model is extremely problematic, especially when it comes to deploying quantized models.

It is clear from this example that the real quantized model is not working as expected. This must either be due to an error I have in my implementation or a bug within PyTorch. Any help determining the issue would be really appreciated, and if you think it is a bug in PyTorch, I will happily create a bug report, but I wanted to check here first.