Custom LSTM PTSQ QconfigMapping

Hello,
I am working on quantizing LSTM using custom module quantization. This is the ObservedLSTM module:

class ObservedLSTM(torch.ao.nn.quantizable.LSTM):
    """
    The observed LSTM layer. This module needs to define a from_float function, which defines
    how the observed module is created from the original fp32 module.
    """
    @classmethod
    def from_float(cls, float_lstm):
        assert isinstance(float_lstm, cls._FLOAT_MODULE)
        # quint8, [0, 1)
        linear_output_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
        )
        # quint8, [0, 1)
        sigmoid_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
        )
        # quint8, [-1, 1)
        tanh_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8
        )
        # quint8, [0, 1)
        cell_state_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
        )
        # quint8, [-1, 1)
        hidden_state_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=2**-7, zero_point=2**7, dtype=torch.quint8
        )
        example_inputs = (torch.rand(1, 50), (torch.rand(1, 50), torch.randn(1, 50)))
        return lstm_utils._get_lstm_with_individually_observed_parts(
            float_lstm=float_lstm,
            example_inputs=example_inputs,
            backend_config=my_backend_config,
            linear_output_obs_ctr=linear_output_obs_ctr,
            sigmoid_obs_ctr=sigmoid_obs_ctr,
            tanh_obs_ctr=tanh_obs_ctr,
            cell_state_obs_ctr=cell_state_obs_ctr,
            hidden_state_obs_ctr=hidden_state_obs_ctr,
        )

I made sure that qparams are aligned with _FIXED_QPARAMS_OP_TO_CONSTRAINTS as I am using FixedQParamsObserver.
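
For reference, that alignment can be checked by just dumping the constraints table (a minimal sketch):

# Minimal sketch: print the fixed qparams constraints the backend expects, so the
# scales / zero_points passed to FixedQParamsObserver above can be compared with them.
from torch.ao.quantization.backend_config._common_operator_config_utils import (
    _FIXED_QPARAMS_OP_TO_CONSTRAINTS,
)

for op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items():
    print(op, constraints)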

I even added these configs:

conf_linear = qconfig.QConfig(
    activation=observer.FixedQParamsObserver.with_args(
        scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
    ),
    weight=observer.FixedQParamsObserver.with_args(
        scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
    ),
)

conf_sigmoid = qconfig.QConfig(
    activation=observer.FixedQParamsObserver.with_args(
        scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
    ),
    weight=observer.FixedQParamsObserver.with_args(
        scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
    ),
)

conf_tanh = qconfig.QConfig(
    activation=observer.FixedQParamsObserver.with_args(
        scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8
    ),
    weight=observer.FixedQParamsObserver.with_args(
        scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8
    ),
)

I still got this warning:

UserWarning: QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize for fixed qparams ops, ignoring QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f928dc8da20>}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f928dc8da20>}).
Please use torch.ao.quantization.get_default_qconfig_mapping or torch.ao.quantization.get_default_qat_qconfig_mapping. Example:
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    model = prepare_fx(model, qconfig_mapping, example_inputs)
  warnings.warn(("QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "

From what I know, .set_object_type() overrides the .set_global() configs.
Am I doing something wrong?

Can you give a fuller repro? At what stage do you get this error?

Thank you @HDCharles for your reply!
I am working on custom module quantization for LSTM layers.
From what I understood, LSTM gates must have FixedQParamsObserver with certain ranges of qparams, as mentioned here: pytorch/torch/ao/quantization/backend_config/_common_operator_config_utils.py at main · pytorch/pytorch (github.com).

Besides, I added QConfigs that are compatible with those observers:

qconfig_mapping = (
    QConfigMapping()
    .set_global(get_default_qconfig("qnnpack"))
    .set_object_type(nn.Linear, conf_linear)
    .set_object_type(nn.Sigmoid, conf_sigmoid)
    .set_object_type(nn.Tanh, conf_tanh)
)

Normally the .set_global() config is overridden by the .set_object_type() ones, so the “qnnpack” observers (HistogramObserver for activations & MinMaxObserver for weights) should be ignored for those object types, which is what I expected to happen.
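
As a sanity check, the entries can be inspected directly on the mapping (a minimal sketch, assuming the global_qconfig / object_type_qconfigs attributes that QConfigMapping uses to store its entries):

# Minimal sketch: confirm the per-object-type QConfigs are registered alongside the
# global one; set_object_type entries take precedence when an op of that type is matched.
print(qconfig_mapping.global_qconfig)                    # qnnpack default
print(qconfig_mapping.object_type_qconfigs[nn.Linear])   # conf_linear
print(qconfig_mapping.object_type_qconfigs[nn.Sigmoid])  # conf_sigmoid
print(qconfig_mapping.object_type_qconfigs[nn.Tanh])     # conf_tanh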

I don’t know why I am still getting this warning!
I hope that I made it a bit clearer!

without a full repro this is hard to diagnose.

Can you explain how what you are doing differs from: https://github.com/pytorch/pytorch/blob/main/test/quantization/fx/test_quantize_fx.py#L4886

I reproduced the same test function in a script and still get the same warning (I just copy-pasted it!)

import torch
import torch.nn as nn
import numpy
from copy import deepcopy

import torch.ao.quantization.observer as observer
from torch.ao.quantization.backend_config._common_operator_config_utils import (
    _FIXED_QPARAMS_OP_TO_CONSTRAINTS,
)
import torch.ao.quantization.qconfig as qconfig

from torch.ao.quantization.backend_config import qnnpack, DTypeConfig, get_qnnpack_backend_config

from torch.ao.quantization.fx.custom_config import (
    PrepareCustomConfig,
    ConvertCustomConfig,
    QConfigMapping,
)
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mapping, get_default_qconfig
from torch.ao.quantization.fx import lstm_utils
import torch.ao.ns._numeric_suite_fx as ns
import torch.ao.ns.fx

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.my_lstm = torch.nn.LSTM(50, 50, 1)

    def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
        x = self.my_lstm(inputs, (h0, c0))
        return x

# Construct a BackendConfig that supports qint32 for certain ops
# TODO: build a BackendConfig from scratch instead of modifying an existing one
qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32)
my_backend_config = get_qnnpack_backend_config()
for config in my_backend_config.configs:
    if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]:
        config.add_dtype_config(qint32_dtype_config)

class UserObservedLSTM(torch.ao.nn.quantizable.LSTM):
    """
    Example of user provided LSTM implementation that assigns fixed qparams
    to the inner ops.
    """
    @classmethod
    def from_float(cls, float_lstm):
        assert isinstance(float_lstm, cls._FLOAT_MODULE)
        # uint16, [-16, 16)
        linear_output_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32)
        # uint16, [0, 1)
        sigmoid_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -16, zero_point=0, dtype=torch.qint32)
        # uint16, [-1, 1)
        tanh_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32)
        # int16, [-16, 16)
        cell_state_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=0, dtype=torch.qint32)
        # uint8, [-1, 1)
        hidden_state_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8)
        example_inputs = (torch.rand(5, 3, 50), (torch.rand(1, 3, 50), torch.randn(1, 3, 50)))
        return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts(
            float_lstm=float_lstm,
            example_inputs=example_inputs,
            backend_config=my_backend_config,
            linear_output_obs_ctr=linear_output_obs_ctr,
            sigmoid_obs_ctr=sigmoid_obs_ctr,
            tanh_obs_ctr=tanh_obs_ctr,
            cell_state_obs_ctr=cell_state_obs_ctr,
            hidden_state_obs_ctr=hidden_state_obs_ctr,
        )

class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM):
    """
    Example of user provided LSTM implementation that produces a reference
    quantized module from a `UserObservedLSTM`.
    """
    @classmethod
    def from_observed(cls, observed_lstm):
        assert isinstance(observed_lstm, cls._FLOAT_MODULE)
        return torch.ao.quantization.fx.lstm_utils._get_reference_quantized_lstm_module(
            observed_lstm=observed_lstm,
            backend_config=my_backend_config,
        )

# FX graph mode quantization
m = MyModel()
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))
prepare_custom_config = PrepareCustomConfig() \
    .set_float_to_observed_mapping(torch.nn.LSTM, UserObservedLSTM)
convert_custom_config = ConvertCustomConfig() \
    .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM)
prepared = prepare_fx(
    m,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config,
    backend_config=my_backend_config,
)
prepared(*example_inputs)
converted = convert_fx(
    prepared,
    convert_custom_config,
    backend_config=my_backend_config,
)
converted(*example_inputs)

This is the warning:

UserWarning: QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize for fixed qparams ops, ignoring QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f66b4195a20>}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f66b4195a20>}).
Please use torch.ao.quantization.get_default_qconfig_mapping or torch.ao.quantization.get_default_qat_qconfig_mapping. Example:
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    model = prepare_fx(model, qconfig_mapping, example_inputs)
  warnings.warn(("QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "

torch version: 2.0.1

Can you just run the test script? Does that test pass for you?

python test/test_quantization.py TestQuantizeFx.test_static_lstm_with_custom_fixed_qparams

ahh i see.

there are 2 tanh ops, one cell_gate, one tanh_cy; only the cell_gate one gets a fixed_qparams qconfig. Given the comment, it looks like it's a known bug.
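
To see which one got which observer, a quick sketch over the prepared model from your repro should show it (just listing the relevant observer submodules by qualified name):

# Quick sketch: print every FixedQParamsObserver / HistogramObserver submodule
# in the prepared model together with its qualified name.
for name, mod in prepared.named_modules():
    if isinstance(mod, (observer.FixedQParamsObserver, observer.HistogramObserver)):
        print(name, type(mod).__name__)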

Yeah, maybe!
Could you please suppress this warning until the bug is fixed?
Besides, can we have an explanation of the comments here?

The dtype is set to torch.qint32 (which is aligned with the backend config), but I don't know why uint16 is mentioned in the comments (are you just putting a uint16 range inside a qint32, i.e. something like casting?).

And as @jerryzh168 mentioned, shouldn’t we use fixed qparams for at least sigmoid and tanh ops?
pytorch/torch/ao/quantization/backend_config/_common_operator_config_utils.py at main · pytorch/pytorch (github.com)

Thank you for the clarification!

It's just a test simulating some weird dtypes using a qint32.

The warning is working as intended: a fixed qparams op is getting matched with your global qconfig. You could either just make your qconfig_mapping handle it or ignore the warning.
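
For example, something along these lines could handle it (a minimal sketch of both options; the fixed qparams scales/zero_points mirror the quint8 constraints for sigmoid and tanh, and default_weight_observer is the stock symmetric qint8 weight observer):

import warnings
import torch
import torch.nn as nn
import torch.ao.quantization.observer as observer
from torch.ao.quantization import QConfig, QConfigMapping, get_default_qconfig

# Option 1: give the fixed qparams ops an explicit fixed qparams QConfig so the
# global qconfig never gets matched against them.
fixed_0to1 = observer.FixedQParamsObserver.with_args(
    scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8)    # sigmoid: [0, 1)
fixed_neg1to1 = observer.FixedQParamsObserver.with_args(
    scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8)  # tanh: [-1, 1)

qconfig_mapping = (
    QConfigMapping()
    .set_global(get_default_qconfig("qnnpack"))
    .set_object_type(nn.Sigmoid, QConfig(activation=fixed_0to1,
                                         weight=observer.default_weight_observer))
    .set_object_type(nn.Tanh, QConfig(activation=fixed_neg1to1,
                                      weight=observer.default_weight_observer))
)

# Option 2: just silence the warning.
warnings.filterwarnings(
    "ignore", message="QConfig must specify a FixedQParamsObserver"
)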