Hello,
I am working on quantizing LSTM layers with post-training static quantization (PTSQ) using torch.fx. I am using the custom LSTM module approach shown in pytorch/test_quantize_fx.py TestQuantizeFx.test_static_lstm.
I simply copy-pasted the example:
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnquantized
import torch.ao.nn.quantizable as nnquantizable
import torch.ao.quantization.observer as observer
from torch.ao.quantization.backend_config import qnnpack, DTypeConfig, get_qnnpack_backend_config
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig, ConvertCustomConfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mapping, get_default_qconfig
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.my_lstm = torch.nn.LSTM(50, 50, 1)
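        # input_size=50, hidden_size=50, num_layers=1; batch_first defaults to
        # False, so inputs are shaped (seq_len, batch, feature)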
    def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
        x = self.my_lstm(inputs, (h0, c0))
        return x
# Construct a BackendConfig that supports qint32 for certain ops
# TODO: build a BackendConfig from scratch instead of modifying an existing one
qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32)
my_backend_config = get_qnnpack_backend_config()
for config in my_backend_config.configs:
    if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]:
        config.add_dtype_config(qint32_dtype_config)
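# Optional sanity check (my addition, not in the original test): add_dtype_config
# appends to BackendPatternConfig.dtype_configs, so the new entry should be there.
for config in my_backend_config.configs:
    if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]:
        assert qint32_dtype_config in config.dtype_configs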
class UserObservedLSTM(torch.ao.nn.quantizable.LSTM):
"""
Example of user provided LSTM implementation that assigns fixed qparams
to the inner ops.
"""
@classmethod
def from_float(cls, float_lstm):
assert isinstance(float_lstm, cls._FLOAT_MODULE)
# uint16, [-16, 16)
linear_output_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32)
# uint16, [0, 1)
sigmoid_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -16, zero_point=0, dtype=torch.qint32)
# uint16, [-1, 1)
tanh_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32)
# int16, [-16, 16)
cell_state_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=0, dtype=torch.qint32)
# uint8, [-1, 1)
hidden_state_obs_ctr = observer.FixedQParamsObserver.with_args(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8)
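        # The range comments above follow real = scale * (q - zero_point): e.g.
        # scale=2**-11 with zero_point=2**15 maps a 16-bit grid q in [0, 2**16)
        # onto [-16, 16), even though the observers are declared as qint32/quint8.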
        example_inputs = (torch.rand(5, 3, 50), (torch.rand(1, 3, 50), torch.randn(1, 3, 50)))
        return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts(
            float_lstm=float_lstm,
            example_inputs=example_inputs,
            backend_config=my_backend_config,
            linear_output_obs_ctr=linear_output_obs_ctr,
            sigmoid_obs_ctr=sigmoid_obs_ctr,
            tanh_obs_ctr=tanh_obs_ctr,
            cell_state_obs_ctr=cell_state_obs_ctr,
            hidden_state_obs_ctr=hidden_state_obs_ctr,
        )
class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM):
"""
Example of user provided LSTM implementation that produces a reference
quantized module from a `UserObservedLSTM`.
"""
@classmethod
def from_observed(cls, observed_lstm):
assert isinstance(observed_lstm, cls._FLOAT_MODULE)
return torch.ao.quantization.fx.lstm_utils._get_reference_quantized_lstm_module(
observed_lstm=observed_lstm,
backend_config=my_backend_config,
)
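# A "reference" quantized module expresses each inner op as a dequantize ->
# float op -> quantize pattern rather than a fused backend kernel, so that a
# backend can lower it later.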
# FX graph mode quantization
m = MyModel()
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))
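# Note the flat tuple here matches MyModel.forward(inputs, h0, c0), while the
# example_inputs inside from_float nest (h0, c0) to match nn.LSTM's own signature.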
prepare_custom_config = PrepareCustomConfig() \
    .set_float_to_observed_mapping(torch.nn.LSTM, UserObservedLSTM)
convert_custom_config = ConvertCustomConfig() \
    .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM)
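# prepare_fx swaps nn.LSTM -> UserObservedLSTM (float_to_observed), and convert_fx
# then swaps that observed module -> UserQuantizedLSTM (observed_to_quantized).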
prepared = prepare_fx(
    m,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config,
    backend_config=my_backend_config,
)
prepared(*example_inputs)
converted = convert_fx(
    prepared,
    convert_custom_config,
    backend_config=my_backend_config,
)
converted(*example_inputs)
and I got this error:
/home/ahmed/anaconda3/envs/test_nntool/bin/python /home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py
Traceback (most recent call last):
  File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py", line 83, in <module>
    prepared = prepare_fx(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 382, in prepare_fx
    return _prepare_fx(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 144, in _prepare_fx
    prepared = prepare(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1842, in prepare
    result_node = insert_observers_for_model(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1582, in insert_observers_for_model
    _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/quantization/fx/prepare.py", line 1277, in _swap_custom_module_to_observed
    observed_custom_module_class.from_float(custom_module)
  File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/PyTorch_API/PTQ/test.py", line 52, in from_float
    return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts(
AttributeError: module 'torch.ao.quantization.fx' has no attribute 'lstm_utils'. Did you mean: 'match_utils'?
I get exactly the same traceback in a second conda environment (torch_custom_lstm). I am using torch version 2.1.0.dev20230622+cu118 (nightly).
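My guess (not verified) is that torch/ao/quantization/fx/__init__.py simply does not re-export the lstm_utils submodule, so the attribute access fails until the submodule is imported explicitly. Something like this at the top of the script is what I would try:

import torch.ao.quantization.fx.lstm_utils  # noqa: F401 -- force the submodule to load

or importing the two helpers directly:

from torch.ao.quantization.fx.lstm_utils import (
    _get_lstm_with_individually_observed_parts,
    _get_reference_quantized_lstm_module,
)

Is that the intended way to use these helpers, or am I missing something?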
Looking for your help!