Hello,
I am working on quantizing LSTM using custom module quantization. This is the ObservedLSTM module:
import torch
from torch.ao.quantization import observer
from torch.ao.quantization.fx import lstm_utils


class ObservedLSTM(torch.ao.nn.quantizable.LSTM):
    """
    The observed LSTM layer. This module needs to define a from_float function which defines
    how the observed module is created from the original fp32 module.
    """

    @classmethod
    def from_float(cls, float_lstm):
        assert isinstance(float_lstm, cls._FLOAT_MODULE)
        # quint8, [0, 1)
        linear_output_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
        )
        # quint8, [0, 1)
        sigmoid_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
        )
        # quint8, [-1, 1)
        tanh_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8
        )
        # quint8, [0, 1)
        cell_state_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
        )
        # quint8, [-1, 1)
        hidden_state_obs_ctr = observer.FixedQParamsObserver.with_args(
            scale=2**-7, zero_point=2**7, dtype=torch.quint8
        )
        example_inputs = (torch.rand(1, 50), (torch.rand(1, 50), torch.randn(1, 50)))
        return lstm_utils._get_lstm_with_individually_observed_parts(
            float_lstm=float_lstm,
            example_inputs=example_inputs,
            backend_config=my_backend_config,  # defined elsewhere in my code
            linear_output_obs_ctr=linear_output_obs_ctr,
            sigmoid_obs_ctr=sigmoid_obs_ctr,
            tanh_obs_ctr=tanh_obs_ctr,
            cell_state_obs_ctr=cell_state_obs_ctr,
            hidden_state_obs_ctr=hidden_state_obs_ctr,
        )
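For completeness, this is roughly how I register the custom module during prepare (a simplified sketch; model, example_inputs and qconfig_mapping are placeholders for my actual objects):

from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.quantize_fx import prepare_fx

# Tell prepare_fx to swap the float nn.LSTM for the ObservedLSTM defined above
prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(
    torch.nn.LSTM, ObservedLSTM
)

prepared = prepare_fx(
    model,            # my fp32 model containing the LSTM
    qconfig_mapping,  # my QConfigMapping (see further down)
    example_inputs,
    prepare_custom_config=prepare_custom_config,
    backend_config=my_backend_config,
)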
I made sure that the qparams are aligned with _FIXED_QPARAMS_OP_TO_CONSTRAINTS, since I am using FixedQParamsObserver.
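Concretely, the sigmoid/tanh values I hard-code above should match the default fixed-qparams observers that ship with PyTorch (a quick check I ran, assuming these observer names exist in my version):

import torch
from torch.ao.quantization.observer import (
    default_fixed_qparams_range_0to1_observer,     # used for sigmoid-like ops
    default_fixed_qparams_range_neg1to1_observer,  # used for tanh-like ops
)

# Instantiate the default observers and print their fixed qparams,
# to compare against the ones I pass to the LSTM above.
sigmoid_obs = default_fixed_qparams_range_0to1_observer()
tanh_obs = default_fixed_qparams_range_neg1to1_observer()
print(sigmoid_obs.scale, sigmoid_obs.zero_point)  # expecting 1/256 and 0
print(tanh_obs.scale, tanh_obs.zero_point)        # expecting 2/256 and 128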
I even added these QConfigs:
from torch.ao.quantization import qconfig

conf_linear = qconfig.QConfig(
    activation=observer.FixedQParamsObserver.with_args(
        scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
    ),
    weight=observer.FixedQParamsObserver.with_args(
        scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
    ),
)
conf_sigmoid = qconfig.QConfig(
    activation=observer.FixedQParamsObserver.with_args(
        scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
    ),
    weight=observer.FixedQParamsObserver.with_args(
        scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
    ),
)
conf_tanh = qconfig.QConfig(
    activation=observer.FixedQParamsObserver.with_args(
        scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8
    ),
    weight=observer.FixedQParamsObserver.with_args(
        scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8
    ),
)
I still got this warning:
UserWarning: QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize for fixed qparams ops, ignoring QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f928dc8da20>}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f928dc8da20>}).
Please use torch.ao.quantization.get_default_qconfig_mapping or torch.ao.quantization.get_default_qat_qconfig_mapping. Example:
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
model = prepare_fx(model, qconfig_mapping, example_inputs)
warnings.warn(("QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
From what I know, .set_object_type() overrides the .set_global() configs.
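In case it helps, this is a simplified sketch of how I build the mapping (the object types and the "fbgemm" backend here are illustrative; my real code may map the configs slightly differently):

from torch.ao.quantization import QConfigMapping, get_default_qconfig

qconfig_mapping = (
    QConfigMapping()
    .set_global(get_default_qconfig("fbgemm"))
    # per-type configs, which should take precedence over the global qconfig
    .set_object_type(torch.nn.Linear, conf_linear)
    .set_object_type(torch.nn.Sigmoid, conf_sigmoid)
    .set_object_type(torch.nn.Tanh, conf_tanh)
)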
Am I doing something wrong?