How to apply per_tensor_symmetric activation quantization?

I wish to perform quantization with a configuration in which both parameters and activations are quantized symmetrically. Here is my code:

import torch
from torchvision import models
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

rn18 = models.resnet18().eval()
data = torch.randn(1, 3, 224, 224)
# Symmetric per-tensor qint8 activations, per-channel weights
qconfig = torch.ao.quantization.QConfig(
    activation=torch.ao.quantization.observer.HistogramObserver.with_args(
        qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
    ),
    weight=torch.ao.quantization.default_per_channel_weight_observer,
)
qconfig_mapping = torch.ao.quantization.QConfigMapping().set_global(qconfig)
prepared = prepare_fx(rn18, qconfig_mapping, data)
for _ in range(10):  # calibration
    prepared(data)
quantized_rn18 = convert_fx(prepared)
quantized_rn18.graph.print_tabular()

The printed result shows that the model is not quantized.
How can I solve this problem?
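For reference, a quicker check than reading the whole print_tabular() output is to scan the converted graph for quantize/dequantize and quantized call targets, roughly like this (just a sketch continuing from the code above; the exact target names can differ between PyTorch versions):

# Sketch: count nodes whose target mentions "quant" (quantize_per_tensor,
# dequantize, quantized::conv2d, ...) in the converted FX graph.
quant_nodes = [n for n in quantized_rn18.graph.nodes if "quant" in str(n.target).lower()]
print(f"quantize/dequantize/quantized nodes: {len(quant_nodes)}")  # 0 means nothing was converted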

You could use torch.ao.quantization.default_symmetric_qnnpack_qconfig, as defined in qconfig.py in the pytorch/pytorch repository on GitHub (commit b8580b08976db89203f2ea7dda0f012520e9471a):

import torch
import torchvision
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

rn18 = torchvision.models.resnet18().eval()
data = torch.randn(1, 3, 224, 224)
# Pre-defined qconfig with signed (qint8) activations and weights for the qnnpack backend
qconfig = torch.ao.quantization.default_symmetric_qnnpack_qconfig
qconfig_mapping = torch.ao.quantization.QConfigMapping().set_global(qconfig)
prepared = prepare_fx(rn18, qconfig_mapping, data)
for _ in range(10):  # calibration
    prepared(data)
quantized_rn18 = convert_fx(prepared)
quantized_rn18.graph.print_tabular()
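If you want to double-check what this qconfig actually configures, you can instantiate its observers and inspect them (a quick sketch):

# Sketch: the QConfig holds observer factories; calling them shows the
# dtype and qscheme they will collect statistics for.
act_obs = qconfig.activation()
wt_obs = qconfig.weight()
print(type(act_obs).__name__, act_obs.dtype, act_obs.qscheme)
print(type(wt_obs).__name__, wt_obs.dtype, wt_obs.qscheme)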

I used default_symmetric_qnnpack_qconfig, but the printed result shows that the model is still not quantized.

I have the same issue. Is there any workaround?

can you post a repro?

Hi @HDCharles,

Thanks for your reply. It's basically the same as @Wenlong_Shi's flow. You can reproduce the issue with the code below, although my actual model is not ResNet:

import os
import torchvision
import torch
from torch.ao.quantization import get_default_qconfig_mapping, default_symmetric_qnnpack_qconfig, get_default_qconfig
from torch.ao.quantization import QConfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# global_qconfig = get_default_qconfig("qnnpack") # this works
global_qconfig = default_symmetric_qnnpack_qconfig # this doesn't work
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
qconfig_mapping.set_global(global_qconfig)

model = torchvision.models.resnet18().eval()
data = torch.randn(1, 3, 224, 224)
model_prepared = prepare_fx(model, qconfig_mapping, data)

for _ in range(10):
    model_prepared(data)

model_quantized = convert_fx(model_prepared)

def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" % (os.path.getsize("tmp.pt") / 1e6))
    os.remove("tmp.pt")

print_model_size(model) # 46.83 MB
print_model_size(model_quantized) # 46.75 MB

The printed sizes show that the model is not quantized.
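Besides the file size, I also sanity-check by counting how many submodules of the converted model come from a quantized namespace, something like this (sketch):

# Sketch: count converted submodules defined in a "quantized" module namespace
# (e.g. torch.ao.nn.quantized.*); 0 means nothing was converted.
n_quantized = sum(1 for _, m in model_quantized.named_modules() if "quantized" in type(m).__module__)
print("quantized submodules:", n_quantized)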

Thank you for your help.

To add a little bit of background, what I would like to do is to match the signedness of the weights and activations (e.g. qint8/qint8). A qint8/quint8 combination such as get_default_qconfig("qnnpack") quantizes the model as expected, but when I set the signs to match, the quantization doesn't work.
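For what it's worth, the signedness difference between the two configurations can be seen by instantiating their activation observers (a small sketch):

from torch.ao.quantization import get_default_qconfig, default_symmetric_qnnpack_qconfig

# Sketch: compare the activation dtypes of the two qconfigs.
print(get_default_qconfig("qnnpack").activation().dtype)      # torch.quint8 (unsigned)
print(default_symmetric_qnnpack_qconfig.activation().dtype)   # torch.qint8 (signed)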

Hi @HDCharles,

Have you had a chance to check my code? I’d like to fix this issue and would appreciate your help.

Thanks.