Custom LSTM static quantization not working

Hello,
I am working on quantizing LSTM layers using post-training static quantization (PTSQ) with torch.fx. I am using a custom LSTM module as shown in pytorch/test_quantize_fx.py (TestQuantizeFx.test_static_lstm).
I have just copy-pasted the example:

import torch
import torch.nn as nn


import torch.ao.nn.quantized
import torch.ao.nn.quantizable
import torch.ao.quantization.observer as observer

from torch.ao.quantization.backend_config import DTypeConfig, get_qnnpack_backend_config

from torch.ao.quantization.fx.custom_config import PrepareCustomConfig, ConvertCustomConfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mapping


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.my_lstm = torch.nn.LSTM(50, 50, 1)

    def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
        x = self.my_lstm(inputs, (h0, c0))
        return x

# Construct a BackendConfig that supports qint32 for certain ops
# TODO: build a BackendConfig from scratch instead of modifying an existing one
qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32)
my_backend_config = get_qnnpack_backend_config()
for config in my_backend_config.configs:
    if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]:
        config.add_dtype_config(qint32_dtype_config)

class UserObservedLSTM(torch.ao.nn.quantizable.LSTM):
    """
    Example of user provided LSTM implementation that assigns fixed qparams
    to the inner ops.
    """
    @classmethod
    def from_float(cls, float_lstm):
        assert isinstance(float_lstm, cls._FLOAT_MODULE)
        # uint16, [-16, 16)
        linear_output_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32)
        # uint16, [0, 1)
        sigmoid_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -16, zero_point=0, dtype=torch.qint32)
        # uint16, [-1, 1)
        tanh_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32)
        # int16, [-16, 16)
        cell_state_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=0, dtype=torch.qint32)
        # uint8, [-1, 1)
        hidden_state_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8)
        example_inputs = (torch.rand(5, 3, 50), (torch.rand(1, 3, 50), torch.randn(1, 3, 50)))
        return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts(
            float_lstm=float_lstm,
            example_inputs=example_inputs,
            backend_config=my_backend_config,
            linear_output_obs_ctr=linear_output_obs_ctr,
            sigmoid_obs_ctr=sigmoid_obs_ctr,
            tanh_obs_ctr=tanh_obs_ctr,
            cell_state_obs_ctr=cell_state_obs_ctr,
            hidden_state_obs_ctr=hidden_state_obs_ctr,
        )


class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM):
    """
    Example of user provided LSTM implementation that produces a reference
    quantized module from a `UserObservedLSTM`.
    """
    @classmethod
    def from_observed(cls, observed_lstm):
        assert isinstance(observed_lstm, cls._FLOAT_MODULE)
        return torch.ao.quantization.fx.lstm_utils._get_reference_quantized_lstm_module(
            observed_lstm=observed_lstm,
            backend_config=my_backend_config,
        )

# FX graph mode quantization
m = MyModel()
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))
prepare_custom_config = PrepareCustomConfig() \
    .set_float_to_observed_mapping(torch.nn.LSTM, UserObservedLSTM)
convert_custom_config = ConvertCustomConfig() \
    .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM)
prepared = prepare_fx(
    m,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config,
    backend_config=my_backend_config,
)
prepared(*example_inputs)
converted = convert_fx(
    prepared,
    convert_custom_config,
    backend_config=my_backend_config,
)
converted(*example_inputs)
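A side note on the fixed qparams in UserObservedLSTM.from_float: each comment names an integer type and a float range, while the code itself passes dtype=torch.qint32 as a stand-in for the 16-bit types. The ranges follow from affine dequantization, x = (q - zero_point) * scale, evaluated at both ends of the integer range. The quick check below is a sketch that takes qmin/qmax from the bit-width named in each comment (an assumption, not something the observer is told), and uses Fraction so the powers of two stay exact.

```python
from fractions import Fraction

# (scale, zero_point, qmin, qmax) per observer. qmin/qmax are ASSUMED from the
# bit-width in each comment (uint16, int16, uint8), not read from the observer.
OBSERVERS = {
    "linear_output (uint16)": (Fraction(1, 2**11), 2**15, 0, 2**16 - 1),
    "sigmoid (uint16)": (Fraction(1, 2**16), 0, 0, 2**16 - 1),
    "tanh (uint16)": (Fraction(1, 2**15), 2**15, 0, 2**16 - 1),
    "cell_state (int16)": (Fraction(1, 2**11), 0, -(2**15), 2**15 - 1),
    "hidden_state (uint8)": (Fraction(1, 2**7), 2**7, 0, 2**8 - 1),
}


def representable_range(scale, zero_point, qmin, qmax):
    """Return [lo, hi) covered by affine dequantization x = (q - zero_point) * scale."""
    lo = (qmin - zero_point) * scale
    hi = (qmax - zero_point) * scale
    # hi + scale is the exclusive upper bound written as "[lo, hi)" in the comments
    return lo, hi + scale


if __name__ == "__main__":
    for name, params in OBSERVERS.items():
        lo, hi_excl = representable_range(*params)
        print(f"{name}: [{float(lo)}, {float(hi_excl)})")
```

Every entry reproduces the range written in the corresponding comment, e.g. (0 - 2^15) * 2^-11 = -16 for the linear output.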

and I got this error:

/home/ahmed/anaconda3/envs/test_nntool/bin/python /home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py
Traceback (most recent call last):
  File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py", line 83, in <module>
    prepared = prepare_fx(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 382, in prepare_fx
    return _prepare_fx(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 144, in _prepare_fx
    prepared = prepare(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1842, in prepare
    result_node = insert_observers_for_model(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1582, in insert_observers_for_model
    _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1277, in _swap_custom_module_to_observed
    observed_custom_module_class.from_float(custom_module)
  File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py", line 52, in from_float
    return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts(
AttributeError: module 'torch.ao.quantization.fx' has no attribute 'lstm_utils'. Did you mean: 'match_utils'?
(test_nntool) ahmed@LAPTOP-71LI719D:~$ conda activate torch_custom_lstm
(torch_custom_lstm) ahmed@LAPTOP-71LI719D:~$ /home/ahmed/anaconda3/envs/torch_custom_lstm/bin/python /home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py
2.1.0.dev20230622+cu118
Traceback (most recent call last):
  File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py", line 83, in <module>
    prepared = prepare_fx(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 382, in prepare_fx
    return _prepare_fx(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 144, in _prepare_fx
    prepared = prepare(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1842, in prepare
    result_node = insert_observers_for_model(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1582, in insert_observers_for_model
    _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1277, in _swap_custom_module_to_observed
    observed_custom_module_class.from_float(custom_module)
  File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py", line 52, in from_float
    return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts(
AttributeError: module 'torch.ao.quantization.fx' has no attribute 'lstm_utils'. Did you mean: 'match_utils'?

I am using torch version: 2.1.0.dev20230622+cu118 (nightly)
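Both tracebacks load torch from /home/ahmed/.local/lib/python3.10/site-packages, even after switching to the torch_custom_lstm conda env. One possible explanation (not confirmed in this thread) is a user-site copy of torch shadowing the env's own install. A stdlib-only check like the sketch below reports which copy actually gets imported and whether it ships lstm_utils; describe_install is a hypothetical helper name.

```python
import importlib
import importlib.util
import site


def describe_install(package: str = "torch",
                     submodule: str = "torch.ao.quantization.fx.lstm_utils") -> dict:
    """Hypothetical helper: report which copy of `package` is imported and
    whether `submodule` exists in it."""
    info = {"package": package, "installed": False, "version": None,
            "location": None, "from_user_site": False, "submodule_available": False}
    try:
        mod = importlib.import_module(package)
    except ImportError:
        return info  # package is not importable at all
    info["installed"] = True
    info["version"] = getattr(mod, "__version__", None)
    info["location"] = getattr(mod, "__file__", None)
    # True when the import resolved to ~/.local/... rather than the env's site-packages
    user_site = site.getusersitepackages()
    info["from_user_site"] = bool(info["location"]) and info["location"].startswith(user_site)
    try:
        # find_spec imports the parent packages, then looks the submodule up without running it
        info["submodule_available"] = importlib.util.find_spec(submodule) is not None
    except ImportError:
        info["submodule_available"] = False
    return info


if __name__ == "__main__":
    for key, value in describe_install().items():
        print(f"{key}: {value}")
```

If from_user_site comes back True inside a conda env, running the script with `python -s` (or with PYTHONNOUSERSITE=1 set) leaves the user site-packages directory off sys.path, so the env's own torch is used instead.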

Looking for your help!

@jerryzh168 please have a look!
Sorry bro, I am being a burden 🤓 😅

It’s here: pytorch/torch/ao/quantization/fx/lstm_utils.py (main branch of pytorch/pytorch on GitHub). I’m not sure what is wrong; maybe you can try rebuilding PyTorch?
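One more thing worth ruling out when the file exists but the lookup still fails: in Python, a submodule only becomes an attribute of its parent package once something imports it, so `torch.ao.quantization.fx.lstm_utils` written as an attribute chain can raise AttributeError even when lstm_utils.py is on disk. Whether torch/ao/quantization/fx/__init__.py imports lstm_utils in the nightly used here has not been verified. A minimal sketch of importing it explicitly, which turns the failure into a clearer "module not found" case; load_lstm_utils is a hypothetical helper, not a PyTorch API:

```python
import importlib
from types import ModuleType
from typing import Optional


def load_lstm_utils(name: str = "torch.ao.quantization.fx.lstm_utils") -> Optional[ModuleType]:
    """Hypothetical helper: import the submodule explicitly instead of relying on
    attribute access, which only works once something has already imported it.

    Returns the module, or None if it (or one of its parent packages) is missing.
    """
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError:
        return None


if __name__ == "__main__":
    lstm_utils = load_lstm_utils()
    if lstm_utils is None:
        print("lstm_utils is not shipped with this torch install")
    else:
        # UserObservedLSTM.from_float could then call
        # lstm_utils._get_lstm_with_individually_observed_parts(...)
        print("found:", lstm_utils.__file__)
```

The plain-statement equivalent is simply `import torch.ao.quantization.fx.lstm_utils` at the top of the script.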

Could you please try running this code to see whether it works before I rebuild PyTorch?

Hello @jerryzh168,
I don’t know what happened, but the “lstm_utils.py” file wasn’t there (I tried different versions of PyTorch).
Do you know whether this approach can also be used for QAT?

Sorry, I haven’t had time to run this. Could you install from source (GitHub - pytorch/pytorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration) and try again?

I built it, but “lstm_utils” is not there. It works now: I just downloaded the file separately (maybe this needs a fix for the next releases).

1- Anyway, I am now trying to export this custom layer to ONNX. Is there any way to make it exportable? It is just a few matrix multiplications and activation functions.
I think the problem is the input tuple, because we have to pass the hidden and cell states along with the inputs.
Do you know anything about this?

2- Do you know whether this approach supports QAT?
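On the ONNX question, which the thread leaves open: MyModel already takes h0 and c0 as separate tensors, but the LSTM still returns a nested tuple (output, (h_n, c_n)). One common approach is a thin wrapper so every input and output is a plain tensor that can be given a name at export time. The sketch below uses a hypothetical FlatIOLSTM wrapper (not a PyTorch class) around a float nn.LSTM; it has not been verified that the quantized module produced by convert_fx exports to ONNX this way.

```python
import torch
import torch.nn as nn


class FlatIOLSTM(nn.Module):
    """Hypothetical wrapper: flat tensor inputs and outputs instead of nested tuples."""

    def __init__(self, lstm: nn.Module):
        super().__init__()
        self.lstm = lstm

    def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
        # Pack the states the way nn.LSTM expects, then unpack the nested result
        output, (h_n, c_n) = self.lstm(inputs, (h0, c0))
        return output, h_n, c_n


if __name__ == "__main__":
    wrapped = FlatIOLSTM(nn.LSTM(50, 50, 1)).eval()
    example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))
    out, h_n, c_n = wrapped(*example_inputs)
    print(out.shape, h_n.shape, c_n.shape)
```

An export call would then look roughly like `torch.onnx.export(wrapped, example_inputs, "lstm.onnx", input_names=["inputs", "h0", "c0"], output_names=["output", "h_n", "c_n"])`; the file name and tensor names are arbitrary.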

This does support QAT, but I can’t help with ONNX; maybe you can contact the ONNX people directly (Quantization — PyTorch 2.0 documentation).