Torch.ao.quantization.fx: Problem with custom LSTM quantization

Hello, I am trying to statically quantize an LSTM layer as shown in pytorch/test_quantize_fx.py (TestQuantizeFx.test_static_lstm).
I copy-pasted the given example and got this error:

Traceback (most recent call last):
  File "/home/ahmed/Desktop/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py", line 106, in <module>
    prepared = prepare_fx(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 380, in prepare_fx
    return _prepare_fx(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 142, in _prepare_fx
    prepared = prepare(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1632, in prepare
    result_node = insert_observers_for_model(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1368, in insert_observers_for_model
    _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1075, in _swap_custom_module_to_observed
    observed_custom_module_class.from_float(custom_module)
  File "/home/ahmed/Desktop/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py", line 69, in from_float
    return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts(
AttributeError: module 'torch.ao.quantization.fx' has no attribute 'lstm_utils'. Did you mean: 'match_utils'?

I rebuilt PyTorch and the problem still persists!
I am using version 2.0.1+cu117.
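
For reference, here is a quick way to check whether the installed wheel actually ships the submodule (a sketch; note that plain attribute access on torch.ao.quantization.fx only resolves after the lstm_utils submodule has been imported somewhere, so an explicit import is the safer pattern):

import importlib.util

# Does the installed wheel ship the submodule at all?
spec = importlib.util.find_spec("torch.ao.quantization.fx.lstm_utils")
print(spec.origin if spec is not None else "lstm_utils is not shipped in this install")

# Prefer an explicit import over attribute access on the package:
from torch.ao.quantization.fx import lstm_utils  # noqa: F401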

Appreciate your help!

I cannot reproduce it in 2.0.1+cu117 and get:

python test_quantization.py TestQuantizeFx.test_static_lstm
/home/pbialecki/miniforge3/envs/2.0.1_cu117/lib/python3.10/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
/home/pbialecki/miniforge3/envs/2.0.1_cu117/lib/python3.10/site-packages/torch/ao/quantization/observer.py:214: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
.
----------------------------------------------------------------------
Ran 1 test in 0.135s

OK

and using a current nightly 2.1.0.dev20230605+cu121:

python test_quantization.py TestQuantizeFx.test_static_lstm
WARNING:root:cannot import name 'ComposableQuantizer' from 'torch.ao.quantization._pt2e.quantizer' (/home/pbialecki/miniforge3/envs/nightly_pip_cu121/lib/python3.10/site-packages/torch/ao/quantization/_pt2e/quantizer/__init__.py)
/home/pbialecki/miniforge3/envs/nightly_pip_cu121/lib/python3.10/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
/home/pbialecki/miniforge3/envs/nightly_pip_cu121/lib/python3.10/site-packages/torch/ao/quantization/observer.py:214: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
.
----------------------------------------------------------------------
Ran 1 test in 0.147s

OK

Thank you for your reply @ptrblck!
I fixed this issue by manually downloading lstm_utils.py (I tried different versions and even rebuilt PyTorch, but this file was always missing).
Besides, I got this warning:

/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/utils.py:829: UserWarning: QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize for fixed qparams ops, ignoring QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f536a8881f0>}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f536a8881f0>}).
Please use torch.ao.quantization.get_default_qconfig_mapping or torch.ao.quantization.get_default_qat_qconfig_mapping. Example:
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    model = prepare_fx(model, qconfig_mapping, example_inputs)
  warnings.warn(("QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "

Should I be concerned about it?
Thank you,

Are you seeing the warning by executing my posted command?
I’m unsure how concerning this warning would be, but @jerryzh168 would know.

Yes, you should follow the suggestion: it means the qconfig is not compatible with those operators. Here is the list of fixed qparams operators: pytorch/torch/ao/quantization/backend_config/_common_operator_config_utils.py at main · pytorch/pytorch · GitHub, and the configurations in your QConfigMapping should match it.
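
For example, something along these lines should satisfy that check for the sigmoid and tanh modules (a sketch using the default fixed qparams observers shipped in torch.ao.quantization.observer; the weight observer here is just the usual symmetric qint8 one):

import torch
from torch.ao.quantization import QConfig, get_default_qconfig_mapping
from torch.ao.quantization.observer import (
    MinMaxObserver,
    default_fixed_qparams_range_0to1_observer,     # scale=1/256, zero_point=0
    default_fixed_qparams_range_neg1to1_observer,  # scale=2/256, zero_point=128
)

weight_observer = MinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
)

# Fixed qparams ops must get a FixedQParamsObserver (or FixedQParamsFakeQuantize)
# as their activation observer, otherwise their qconfig is ignored.
sigmoid_qconfig = QConfig(
    activation=default_fixed_qparams_range_0to1_observer,
    weight=weight_observer,
)
tanh_qconfig = QConfig(
    activation=default_fixed_qparams_range_neg1to1_observer,
    weight=weight_observer,
)

qconfig_mapping = (
    get_default_qconfig_mapping("qnnpack")
    .set_object_type(torch.nn.Sigmoid, sigmoid_qconfig)
    .set_object_type(torch.nn.Tanh, tanh_qconfig)
)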


Understood, thank you!
This was copy-pasted from your example, so maybe you should take a look at it!

Hello @jerryzh168,
I have tried multiple configurations and the warning still persists. This is my code:

import torch
import torch.nn as nn
import numpy

print(torch.__version__)
import torch.ao.nn.quantized as nnquantized
import torch.ao.nn.quantizable as nnquantizable
import torch.ao.quantization.observer as observer
from torch.ao.quantization.backend_config._common_operator_config_utils import _add_fixed_qparams_to_dtype_configs, _FIXED_QPARAMS_OP_TO_CONSTRAINTS
import torch.ao.quantization.qconfig as qconfig

from torch.ao.quantization.backend_config import qnnpack, DTypeConfig, get_qnnpack_backend_config

from torch.ao.quantization.fx.custom_config import (
    PrepareCustomConfig,
    ConvertCustomConfig,
    QConfigMapping,
)
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mapping, get_default_qconfig
from torch.ao.quantization.fx import lstm_utils
from collections import namedtuple


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.my_lstm = torch.nn.LSTM(50, 50, 1)

    def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
        x = self.my_lstm(inputs, (h0, c0))
        return x


# Construct a BackendConfig that supports qint32 for certain ops
# TODO: build a BackendConfig from scratch instead of modifying an existing one

qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32)
my_backend_config = get_qnnpack_backend_config()
for config in my_backend_config.configs:
    if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]:
        config.add_dtype_config(qint32_dtype_config)


class UserObservedLSTM(torch.ao.nn.quantizable.LSTM):
    """
    Example of user provided LSTM implementation that assigns fixed qparams
    to the inner ops.
    """

    @classmethod
    def from_float(cls, float_lstm):
        assert isinstance(float_lstm, cls._FLOAT_MODULE)
        # uint16, [-16, 16)
        linear_output_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=2**-11, zero_point=2**15, dtype=torch.qint32
        )
        # uint8, [0, 1)
        sigmoid_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
        )
        # uint8, [-1, 1)
        tanh_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8
        )
        # qint32, custom scale (0.11), zero_point=0
        cell_state_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=0.11, zero_point=0, dtype=torch.qint32
        )
        # uint8-style range [-1, 1), stored as qint32
        hidden_state_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=2**-7, zero_point=2**7, dtype=torch.qint32
        )
        example_inputs = (torch.rand(1, 50), (torch.rand(1, 50), torch.randn(1, 50)))
        return lstm_utils._get_lstm_with_individually_observed_parts(
            float_lstm=float_lstm,
            example_inputs=example_inputs,
            backend_config=my_backend_config,
            linear_output_obs_ctr=linear_output_obs_ctr,
            sigmoid_obs_ctr=sigmoid_obs_ctr,
            tanh_obs_ctr=tanh_obs_ctr,
            cell_state_obs_ctr=cell_state_obs_ctr,
            hidden_state_obs_ctr=hidden_state_obs_ctr,
        )


class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM):
    """
    Example of user provided LSTM implementation that produces a reference
    quantized module from a `UserObservedLSTM`.
    """

    @classmethod
    def from_observed(cls, observed_lstm):
        assert isinstance(observed_lstm, cls._FLOAT_MODULE)
        return lstm_utils._get_reference_quantized_lstm_module(
            observed_lstm=observed_lstm,
            backend_config=my_backend_config,
        )


# FX graph mode quantization
m = MyModel().eval()

#conf_linear = qconfig.QConfig( activation= observer.FixedQParamsObserver.with_args( scale=2**-11, zero_point=2**15, dtype=torch.qint32) ) 
#conf_sigmoid = qconfig.QConfig( activation= observer.FixedQParamsObserver.with_args( scale=2**-16, zero_point=0, dtype=torch.qint32 ) )
#conf_tanh = qconfig.QConfig(activation= observer.FixedQParamsObserver.with_args( observer.FixedQParamsObserver.with_args( scale=2**-15, zero_point=2**15, dtype=torch.qint32 ) ) ) 
                               


#qconfig_mapping = QConfigMapping().set_module_name(torch.nn.Linear,conf_linear ) \
#                                  .set_module_name(torch.nn.Sigmoid, conf_sigmoid) \
#                                  .set_module_name(torch.nn.Tanh, conf_tanh)


#sig_config = qconfig.QConfig(activation= observer.FixedQParamsObserver(scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8).with_args(x=1),
#                         weight= observer.FixedQParamsObserver(scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8).with_args(dtype=torch.quint8).with_args(x=1))

#qconfig_mapping = QConfigMapping().set_global(get_default_qconfig_mapping("qnnpack")) \
#                                  .set_module_name("sigmoid_obs_ctr", sig_config)

qconfig_mapping = get_default_qconfig_mapping("qnnpack")

example_inputs = (torch.rand(1, 1, 50), torch.rand(1, 1, 50), torch.randn(1, 1, 50))
prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(
    torch.nn.LSTM, UserObservedLSTM
)
convert_custom_config = ConvertCustomConfig().set_observed_to_quantized_mapping(
    torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM
)
prepared = prepare_fx(
    m,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config,
    backend_config=my_backend_config,
)
prepared.print_readable()
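
For reference, the full set of fixed qparams ops and the exact constraints each one expects can be printed directly (a sketch; _FIXED_QPARAMS_OP_TO_CONSTRAINTS is a private mapping, so its layout may change between versions):

from torch.ao.quantization.backend_config._common_operator_config_utils import (
    _FIXED_QPARAMS_OP_TO_CONSTRAINTS,
)

# Each entry maps a fixed qparams op (e.g. torch.nn.Sigmoid) to the
# dtype/scale/zero_point constraints its qconfig has to match.
for op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items():
    print(op, constraints)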

I even replaced the qparams with the ones listed here: pytorch/torch/ao/quantization/backend_config/_common_operator_config_utils.py at main · pytorch/pytorch · GitHub for the sigmoid and tanh operations, but the quantization is still not done properly.
Output:

/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.10/site-packages/torch/ao/quantization/fx/utils.py:829: UserWarning: QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize for fixed qparams ops, ignoring QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f042ee47f40>}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f042ee47f40>}).
Please use torch.ao.quantization.get_default_qconfig_mapping or torch.ao.quantization.get_default_qat_qconfig_mapping. Example:
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    model = prepare_fx(model, qconfig_mapping, example_inputs)
  warnings.warn(("QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
class GraphModule(torch.nn.Module):
    def forward(self, inputs : torch.Tensor, h0 : torch.Tensor, c0 : torch.Tensor):
        # No stacktrace found for following nodes
        activation_post_process_0 = self.activation_post_process_0(inputs);  inputs = None
        activation_post_process_1 = self.activation_post_process_1(h0);  h0 = None
        activation_post_process_2 = self.activation_post_process_2(c0);  c0 = None
        
        # File: /home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/custom_lstm_models/custom_LSTM_PyTorchAPI_example.py:32, code: x = self.my_lstm(inputs, (h0, c0))
        my_lstm = self.my_lstm(activation_post_process_0, (activation_post_process_1, activation_post_process_2));  activation_post_process_0 = activation_post_process_1 = activation_post_process_2 = None
        
        # No stacktrace found for following nodes
        getitem = my_lstm[0]
        dequant_stub_0 = self.dequant_stub_0(getitem);  getitem = None
        getitem_1 = my_lstm[1];  my_lstm = None
        getitem_2 = getitem_1[0]
        dequant_stub_1 = self.dequant_stub_1(getitem_2);  getitem_2 = None
        getitem_3 = getitem_1[1];  getitem_1 = None
        dequant_stub_2 = self.dequant_stub_2(getitem_3);  getitem_3 = None
        tuple_1 = tuple([dequant_stub_1, dequant_stub_2]);  dequant_stub_1 = dequant_stub_2 = None
        tuple_2 = tuple([dequant_stub_0, tuple_1]);  dequant_stub_0 = tuple_1 = None
        return tuple_2

There is no quantize_per_tensor() operation, so I assume quantization is skipped because of the heterogeneity of the configs.
Can you find a way to make this work, please? Or is there documentation I can read?
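
A quick way to double-check is to calibrate, convert, and then search the converted graph for quantize ops (a sketch continuing the code above; node targets are matched loosely):

# Calibrate with a single forward pass, then convert.
with torch.no_grad():
    prepared(*example_inputs)

converted = convert_fx(
    prepared,
    convert_custom_config=convert_custom_config,
    backend_config=my_backend_config,
)
converted.print_readable()

# True only if the converted graph actually contains quantize_per_tensor nodes.
print(any("quantize_per_tensor" in str(node.target) for node in converted.graph.nodes))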

I really appreciate your help!
Thank you,