QNNPACK with activation dtype int8 is not runnable

Hi @jerryzh168, when I convert a model using

qconfig=default_symmetric_qnnpack_qat_qconfig

the JIT model cannot run with libtorch 2.0.1 (C++) on Win10 with VS2019,

but if I use the following qconfig

qconfig=get_default_qat_qconfig(backend='qnnpack', version=1)

it will work.
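For context, the two qconfigs differ in the activation dtype they produce. A quick way to see this (a small sketch of my own; the `.dtype` attribute on the instantiated activation fake-quant is assumed from the 2.0.x API):

import torch
from torch.ao.quantization.qconfig import (
    default_symmetric_qnnpack_qat_qconfig,
    get_default_qat_qconfig,
)

symmetric_qconfig = default_symmetric_qnnpack_qat_qconfig
default_qconfig = get_default_qat_qconfig(backend='qnnpack', version=1)

# The symmetric qnnpack QAT qconfig uses qint8 activations, while the default
# qnnpack QAT qconfig uses quint8 activations.
print(symmetric_qconfig.activation().dtype)  # expected: torch.qint8
print(default_qconfig.activation().dtype)    # expected: torch.quint8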

The code I used is below:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv_test=torch.nn.Conv2d(1,3,(3,3))
    def forward(self,x):
        x=self.conv_test(x)
        return x
input=torch.randn(1,1,15,15)
model=Model()

import torch.ao.quantization.quantize_fx as quantize_fx
import copy

model_to_quantize=copy.deepcopy(model)
model_to_quantize.train()

from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.qconfig import default_symmetric_qnnpack_qat_qconfig,get_default_qat_qconfig
qconfig=get_default_qat_qconfig(backend='qnnpack', version=1)
# qconfig=default_symmetric_qnnpack_qat_qconfig
qconfig_mapping = QConfigMapping().set_global(qconfig)
backend_config=torch.ao.quantization.backend_config.qnnpack.get_qnnpack_backend_config()
torch.backends.quantized.engine = 'qnnpack'

model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize,qconfig_mapping=qconfig_mapping,example_inputs=input,backend_config=backend_config)
model_prepared(input)
model_prepared(input)
model_quantized = quantize_fx.convert_fx(model_prepared,qconfig_mapping=qconfig_mapping,backend_config=backend_config)

model_trace=torch.jit.trace(model_quantized,input)
torch.jit.save(model_trace,'conv_layer.pth')

PyTorch version: 2.0.1+cu117

what is the error message?


The source code I use is:

This looks like a lint error? Does it actually result in real errors?

It's actually a runtime error.
I have also run the JIT model using C++ on a Linux x86_64 platform; the error produced by the following qconfig is clearer:

qconfig=default_symmetric_qnnpack_qat_qconfig

error message:

(torch2.01) /libTorchTest/build# ./example-app 
loading model success
thread number of the machine is:48
terminate called after throwing an instance of 'std::runtime_error'
  what():  The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torch/fx/graph_module.py", line 15, in forward
    conv_test_input_scale_0 = self.conv_test_input_scale_0
    input = torch.quantize_per_tensor(x, conv_test_input_scale_0, conv_test_input_zero_point_0, 12)
    _0 = torch.dequantize((conv_test).forward(input, ))
                           ~~~~~~~~~~~~~~~~~~ <--- HERE
    return _0
  File "code/__torch__/torch/ao/nn/quantized/modules/conv.py", line 86, in forward
    input: Tensor) -> Tensor:
    _packed_params = self._packed_params
    conv_test = ops.quantized.conv2d(input, _packed_params, 0.017023416236042976, -16)
                ~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return conv_test

Traceback of TorchScript, original code (most recent call last):
/home/sfadmin/anaconda3/envs/torch2.01/lib/python3.9/site-packages/torch/_ops.py(502): __call__
/home/sfadmin/anaconda3/envs/torch2.01/lib/python3.9/site-packages/torch/ao/nn/quantized/modules/conv.py(469): forward
/home/sfadmin/anaconda3/envs/torch2.01/lib/python3.9/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home/sfadmin/anaconda3/envs/torch2.01/lib/python3.9/site-packages/torch/nn/modules/module.py(1501): _call_impl
<eval_with_key>.6(8): forward
/home/sfadmin/anaconda3/envs/torch2.01/lib/python3.9/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home/sfadmin/anaconda3/envs/torch2.01/lib/python3.9/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home/sfadmin/anaconda3/envs/torch2.01/lib/python3.9/site-packages/torch/fx/graph_module.py(271): __call__
/home/sfadmin/anaconda3/envs/torch2.01/lib/python3.9/site-packages/torch/fx/graph_module.py(662): call_wrapped
/home/sfadmin/anaconda3/envs/torch2.01/lib/python3.9/site-packages/torch/jit/_trace.py(1056): trace_module
/home/sfadmin/anaconda3/envs/torch2.01/lib/python3.9/site-packages/torch/jit/_trace.py(794): trace
quant_2.py(33): <module>
RuntimeError: quantized::conv2d (ONEDNN): data type of input should be QUint8.

Aborted (core dumped)

It seems that the qconfig with symmetric activation cannot be executed on x86_64, but the Python script works well with both qconfigs.

All the code I used is as follows:
Python script:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv_test=torch.nn.Conv2d(1,3,(3,3))
    def forward(self,x):
        x=self.conv_test(x)
        return x
input=torch.randn(1,1,15,15)
model=Model()

import torch.ao.quantization.quantize_fx as quantize_fx
import copy

model_to_quantize=copy.deepcopy(model)
model_to_quantize.train()

from torch.ao.quantization.qconfig_mapping import QConfigMapping,_get_symmetric_qnnpack_qconfig_mapping
from torch.ao.quantization.qconfig import default_symmetric_qnnpack_qat_qconfig,get_default_qat_qconfig
# qconfig=get_default_qat_qconfig(backend='qnnpack', version=1)
qconfig=default_symmetric_qnnpack_qat_qconfig
qconfig_mapping = QConfigMapping().set_global(qconfig)
backend_config=torch.ao.quantization.backend_config.qnnpack.get_qnnpack_backend_config()
torch.backends.quantized.engine = 'qnnpack'

model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize,qconfig_mapping=qconfig_mapping,example_inputs=input,backend_config=backend_config)
model_prepared(input)
model_prepared(input)
model_quantized = quantize_fx.convert_fx(model_prepared,qconfig_mapping=qconfig_mapping,backend_config=backend_config)


model_trace=torch.jit.trace(model_quantized,input)
torch.jit.save(model_trace,'conv_layer_3.pth')

model_trace_reload=torch.jit.load('conv_layer_3.pth')

result_ref=model_quantized(input)
result_model_trace_reload=model_trace_reload(input)
print(torch.allclose(result_ref,result_model_trace_reload))

The C++ test code:

#include <iostream>
#include "torch/torch.h"
#include "torch/script.h"
int main(int argc, char* argv[])
{
    torch::jit::script::Module module;
    try{
        module=torch::jit::load("./conv_layer_3.pth");
    }
    catch(const c10::Error& e){
        std::cerr<<"error loading model:"<<e.what()<<std::endl;
        throw;
    }
    std::cout<<"loading model success"<<std::endl;
    std::cout<<"thread number of the machine is:"<<torch::get_num_threads()<<std::endl;
    torch::set_num_threads(1);
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::randn({1,1,15,15}));
    at::Tensor output=module.forward(inputs).toTensor();
    std::cout<<output<<std::endl;
    std::cout<<"calc finished"<<std::endl;
}

CMakeLists.txt:

cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(example-app)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(example-app test.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 17)

looks like the error is: RuntimeError: quantized::conv2d (ONEDNN): data type of input should be QUint8.

Regarding "It seems that the qconfig of symmetric activation can not be executed on x86_64": I think x86_64 is expecting the activation dtype to be QUInt8, not the QInt8 from "get_qnnpack_backend_config".
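For anyone who wants to double-check that, here is a rough sketch (assuming BackendConfig.configs, BackendPatternConfig.pattern / .dtype_configs, and the DTypeConfig dtype properties behave as in the 2.0.x API) that lists the dtypes the qnnpack backend config advertises for nn.Conv2d:

import torch
from torch.ao.quantization.backend_config import get_qnnpack_backend_config

# Print the (input, weight, output) dtypes the qnnpack BackendConfig lists
# for nn.Conv2d. qint8 activations appear here, but the quantized::conv2d
# kernels used by the fbgemm/onednn (x86) engines only accept quint8 inputs,
# which matches the runtime error above.
for pattern_config in get_qnnpack_backend_config().configs:
    if pattern_config.pattern is torch.nn.Conv2d:
        for dtype_config in pattern_config.dtype_configs:
            print(dtype_config.input_dtype,
                  dtype_config.weight_dtype,
                  dtype_config.output_dtype)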

Are you trying to use qnnpack in the JIT runtime?

Yes, I am trying to use it with the C++ backend on an x86_64 platform.
I see that QInt8 input and output are supported by "get_qnnpack_backend_config" and can be executed in the Python script, but it fails in the C++ environment.

in that case you should use fbgemm/x86 configs, not qnnpack configs I think: https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/qconfig.py#L313
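If it helps, here is a minimal sketch of that change applied to the script above (untested here; the 'x86' engine name, get_default_qat_qconfig_mapping, and get_native_backend_config are assumed from the PyTorch 2.0 API):

# Minimal sketch: the same prepare/convert flow, but with the x86/fbgemm
# configuration instead of qnnpack, so activations stay quint8.
import copy
import torch
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.backend_config import get_native_backend_config

model = torch.nn.Conv2d(1, 3, (3, 3))        # stand-in for the Model above
example_input = torch.randn(1, 1, 15, 15)

torch.backends.quantized.engine = 'x86'      # or 'fbgemm'
qconfig_mapping = get_default_qat_qconfig_mapping('x86')
backend_config = get_native_backend_config()

model_to_quantize = copy.deepcopy(model).train()
model_prepared = quantize_fx.prepare_qat_fx(
    model_to_quantize, qconfig_mapping,
    example_inputs=(example_input,), backend_config=backend_config)
model_prepared(example_input)                # QAT / calibration step
model_quantized = quantize_fx.convert_fx(
    model_prepared, qconfig_mapping=qconfig_mapping,
    backend_config=backend_config)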

Is there future support planned for qint8 input/output on x86_64? It's inefficient to test large data using qnnpack backend devices.

I think you can ask @leslie-fang-intel @Jiong_Gong

For x86, we support quint8 with asymmetric quantization for activations (weight is qint8 with symmetric per-tensor/per-channel). Are you able to use this configuration instead?
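In qconfig terms, a hand-written sketch of that combination (for illustration only; the stock fbgemm/x86 qconfig already encodes the same choice, and the observer arguments here are my assumptions, not an official recipe):

import torch
from torch.ao.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver

# quint8 asymmetric (affine) activations + qint8 symmetric per-channel weights,
# i.e. the combination the x86 backend supports.
x86_style_qconfig = QConfig(
    activation=HistogramObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True),
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)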

From x86.py, I know that the supported activation dtype is quint8 for conv2d on the x86 backend. Is there future support planned for qint8 activation on x86_64? QNNPACK has the qint8 activation config, but the exported model is not usable in my Windows environment.

Hello, I also meet the error "RuntimeError: quantized::conv2d (ONEDNN): data type of input should be QUint8." when I try to export an S8S8 (activation qint8) quantized model to ONNX. Is there any solution for this?

Maybe this will help you: jit model of qnnpack's int8 activation can not be executed on x86_64 using libtorch but python script works well · Issue #106598 · pytorch/pytorch · GitHub

I modified my code by adding:

q_backend = "qnnpack"  # qnnpack  or fbgemm
torch.backends.quantized.engine = q_backend

and I got:
quantized::linear (xnnpack): Unsupported config for dtype KQInt8
It seems changing the backend makes sense, but how do I make the config supported?
My config is:

qconfig = torch.ao.quantization.qconfig.QConfig(
    activation=torch.ao.quantization.observer.HistogramObserver.with_args(
        qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
    weight=torch.ao.quantization.observer.default_per_channel_weight_observer,
)
qconfig_mapping = QConfigMapping().set_global(qconfig)

It seems that PyTorch versions below 2.0 do not support qint8 input. You can update your PyTorch to the latest version and try again, following the link above.
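A quick sanity check of the installed build before retrying (nothing here is specific to this issue, just the version and the compiled-in quantized engines):

import torch

print(torch.__version__)
print(torch.backends.quantized.supported_engines)  # engines this build was compiled with
print(torch.backends.quantized.engine)             # currently selected engine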

Thank you for the reply. I have resolved it by changing my backend, but QNNPACK does not currently support a per-channel fully connected op… :rofl: It is hard to convert an FX quantized model to ONNX.