Hello @jerryzh168,
I have tried multiple configurations and the warning still persists. This is my code:
import torch
import torch.nn as nn
import numpy
print(torch.__version__)
import torch.ao.nn.quantized as nnquantized
import torch.ao.nn.quantizable as nnquantizable
import torch.ao.quantization.observer as observer
from torch.ao.quantization.backend_config._common_operator_config_utils import _add_fixed_qparams_to_dtype_configs, _FIXED_QPARAMS_OP_TO_CONSTRAINTS
import torch.ao.quantization.qconfig as qconfig
from torch.ao.quantization.backend_config import qnnpack, DTypeConfig, get_qnnpack_backend_config
from torch.ao.quantization.fx.custom_config import (
PrepareCustomConfig,
ConvertCustomConfig,
QConfigMapping,
)
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mapping, get_default_qconfig
from torch.ao.quantization.fx import lstm_utils
from collections import namedtuple
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.my_lstm = torch.nn.LSTM(50, 50, 1)

    def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
        x = self.my_lstm(inputs, (h0, c0))
        return x
# Construct a BackendConfig that supports qint32 for certain ops
# TODO: build a BackendConfig from scratch instead of modifying an existing one
qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32)
my_backend_config = get_qnnpack_backend_config()
for config in my_backend_config.configs:
    if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]:
        config.add_dtype_config(qint32_dtype_config)
class UserObservedLSTM(torch.ao.nn.quantizable.LSTM):
    """
    Example of user provided LSTM implementation that assigns fixed qparams
    to the inner ops.
    """
    @classmethod
    def from_float(cls, float_lstm):
        assert isinstance(float_lstm, cls._FLOAT_MODULE)
        # uint16, [-16, 16)
        linear_output_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=2**-11, zero_point=2**15, dtype=torch.qint32
        )
        # uint8, [0, 1)
        sigmoid_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
        )
        # uint8, [-1, 1)
        tanh_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=2.0 / 256.0, zero_point=128, dtype=torch.qint32
        )
        # int16, [-16, 16)
        cell_state_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=0.11, zero_point=0, dtype=torch.qint32
        )
        # [-1, 1) with 8-bit step size, stored as qint32
        hidden_state_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=2**-7, zero_point=2**7, dtype=torch.qint32
        )
        example_inputs = (torch.rand(1, 50), (torch.rand(1, 50), torch.randn(1, 50)))
        return lstm_utils._get_lstm_with_individually_observed_parts(
            float_lstm=float_lstm,
            example_inputs=example_inputs,
            backend_config=my_backend_config,
            linear_output_obs_ctr=linear_output_obs_ctr,
            sigmoid_obs_ctr=sigmoid_obs_ctr,
            tanh_obs_ctr=tanh_obs_ctr,
            cell_state_obs_ctr=cell_state_obs_ctr,
            hidden_state_obs_ctr=hidden_state_obs_ctr,
        )
class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM):
    """
    Example of user provided LSTM implementation that produces a reference
    quantized module from a `UserObservedLSTM`.
    """
    @classmethod
    def from_observed(cls, observed_lstm):
        assert isinstance(observed_lstm, cls._FLOAT_MODULE)
        return lstm_utils._get_reference_quantized_lstm_module(
            observed_lstm=observed_lstm,
            backend_config=my_backend_config,
        )
# FX graph mode quantization
m = MyModel().eval()
#conf_linear = qconfig.QConfig( activation= observer.FixedQParamsObserver.with_args( scale=2**-11, zero_point=2**15, dtype=torch.qint32) )
#conf_sigmoid = qconfig.QConfig( activation= observer.FixedQParamsObserver.with_args( scale=2**-16, zero_point=0, dtype=torch.qint32 ) )
#conf_tanh = qconfig.QConfig(activation= observer.FixedQParamsObserver.with_args( observer.FixedQParamsObserver.with_args( scale=2**-15, zero_point=2**15, dtype=torch.qint32 ) ) )
#qconfig_mapping = QConfigMapping().set_module_name(torch.nn.Linear,conf_linear ) \
# .set_module_name(torch.nn.Sigmoid, conf_sigmoid) \
# .set_module_name(torch.nn.Tanh, conf_tanh)
#sig_config = qconfig.QConfig(activation= observer.FixedQParamsObserver(scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8).with_args(x=1),
# weight= observer.FixedQParamsObserver(scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8).with_args(dtype=torch.quint8).with_args(x=1))
#qconfig_mapping = QConfigMapping().set_global(get_default_qconfig_mapping("qnnpack")) \
# .set_module_name("sigmoid_obs_ctr", sig_config)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
example_inputs = (torch.rand(1, 1, 50), torch.rand(1, 1, 50), torch.randn(1, 1, 50))
prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(
    torch.nn.LSTM, UserObservedLSTM
)
convert_custom_config = ConvertCustomConfig().set_observed_to_quantized_mapping(
    torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM
)
prepared = prepare_fx(
    m,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config,
    backend_config=my_backend_config,
)
prepared.print_readable()
I even changed the qparams for the sigmoid and tanh operations to the ones in torch/ao/quantization/backend_config/_common_operator_config_utils.py (on pytorch/pytorch main), but the quantization is still not done properly.
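For reference, this is how I am double-checking the values I pass against those defaults (just a small sketch; I am assuming the predefined default_fixed_qparams_range_0to1_observer and default_fixed_qparams_range_neg1to1_observer in torch.ao.quantization.observer are the ones the sigmoid/tanh constraints are derived from, please correct me if not):

from torch.ao.quantization.observer import (
    default_fixed_qparams_range_0to1_observer,     # what I believe sigmoid is checked against
    default_fixed_qparams_range_neg1to1_observer,  # what I believe tanh is checked against
)

# Instantiate the predefined observers and print their fixed qparams, to compare
# against the scale/zero_point/dtype I pass to FixedQParamsObserver.with_args above.
for name, obs_ctr in [
    ("sigmoid, range [0, 1)", default_fixed_qparams_range_0to1_observer),
    ("tanh,    range [-1, 1)", default_fixed_qparams_range_neg1to1_observer),
]:
    obs = obs_ctr()
    print(name, "-> scale:", obs.scale.item(), "zero_point:", obs.zero_point.item(), "dtype:", obs.dtype)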
Output:
/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.10/site-packages/torch/ao/quantization/fx/utils.py:829: UserWarning: QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize for fixed qparams ops, ignoring QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f042ee47f40>}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f042ee47f40>}).
Please use torch.ao.quantization.get_default_qconfig_mapping or torch.ao.quantization.get_default_qat_qconfig_mapping. Example:
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
model = prepare_fx(model, qconfig_mapping, example_inputs)
warnings.warn(("QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
class GraphModule(torch.nn.Module):
    def forward(self, inputs : torch.Tensor, h0 : torch.Tensor, c0 : torch.Tensor):
        # No stacktrace found for following nodes
        activation_post_process_0 = self.activation_post_process_0(inputs); inputs = None
        activation_post_process_1 = self.activation_post_process_1(h0); h0 = None
        activation_post_process_2 = self.activation_post_process_2(c0); c0 = None
        # File: /home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/custom_lstm_models/custom_LSTM_PyTorchAPI_example.py:32, code: x = self.my_lstm(inputs, (h0, c0))
        my_lstm = self.my_lstm(activation_post_process_0, (activation_post_process_1, activation_post_process_2)); activation_post_process_0 = activation_post_process_1 = activation_post_process_2 = None
        # No stacktrace found for following nodes
        getitem = my_lstm[0]
        dequant_stub_0 = self.dequant_stub_0(getitem); getitem = None
        getitem_1 = my_lstm[1]; my_lstm = None
        getitem_2 = getitem_1[0]
        dequant_stub_1 = self.dequant_stub_1(getitem_2); getitem_2 = None
        getitem_3 = getitem_1[1]; getitem_1 = None
        dequant_stub_2 = self.dequant_stub_2(getitem_3); getitem_3 = None
        tuple_1 = tuple([dequant_stub_1, dequant_stub_2]); dequant_stub_1 = dequant_stub_2 = None
        tuple_2 = tuple([dequant_stub_0, tuple_1]); dequant_stub_0 = tuple_1 = None
        return tuple_2
There is no quantize_per_tensor() operation, so I assume that quantization is skipped because of the heterogeneity of the configs.
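For completeness, this is roughly how I planned to check for those nodes after calibration and convert_fx (just a sketch, since I have not gotten past the warning above; the calibration loop only feeds random placeholder data with the same shapes as example_inputs):

# Calibrate the prepared model with placeholder random inputs
with torch.no_grad():
    for _ in range(10):
        prepared(torch.rand(1, 1, 50), torch.rand(1, 1, 50), torch.randn(1, 1, 50))

# Convert to the reference quantized model using the custom config defined above
converted = convert_fx(
    prepared,
    convert_custom_config=convert_custom_config,
    backend_config=my_backend_config,
)

# Look for quantize_per_tensor nodes in the converted graph
quant_nodes = [
    n for n in converted.graph.nodes
    if n.op == "call_function" and n.target is torch.quantize_per_tensor
]
print("quantize_per_tensor nodes found:", len(quant_nodes))
converted.print_readable()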
Can you find a way to make this work, please? Or is there documentation I can read?
I really appreciate your help!
Thank you,