LSTM PTSQ using Eager Mode

I am trying to quantize an LSTM layer using PTSQ in Eager Mode and then export it to ONNX.
I am using the default “qnnpack” qconfig.

import torch.nn

class M(torch.nn.Module):
    def __init__(self, h0, c0):
        super().__init__()
        
        self.h0 = h0
        self.c0 = c0
        
        self.quant = torch.ao.quantization.QuantStub()
        self.rnn = torch.nn.LSTM(10, 20, 2)
        self.dequant = torch.ao.quantization.DeQuantStub()
        
    def forward(self, x):
        # during the convert step, this will be replaced with a
        # `quantize_per_tensor` call
        x = self.quant(x)
        x = self.rnn( x , (self.h0, self.c0) )
        x = self.dequant(x)
        return x
    
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
dummy_input = torch.rand(5, 3, 10)    

test_fp32 = M(h0,c0).eval()

output, (h0, c0) = test_fp32(dummy_input)


test_fp32.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')

model_fp32_prepared = torch.ao.quantization.prepare(test_fp32)

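# calibrate: run representative data through the prepared model
# so the observers can record activation ranges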
model_fp32_prepared(dummy_input)
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)


torch.onnx.export(model_int8, (dummy_input), f="EM_lstm_PTSQ.onnx")

I got this error log:

============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "/home/ahmed/Desktop/pulse_ai/scripts/tiny/PyTorch_API/PTQ/EM_lstm_quant.py", line 43, in <module>
    torch.onnx.export(model_int8,(dummy_input),f= "EM_lstm_PTSQ.onnx")
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1112, in _model_to_graph
    model = _pre_trace_quant_model(model, args)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1067, in _pre_trace_quant_model
    return torch.jit.trace(model, args)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/ahmed/Desktop/pulse_ai/scripts/tiny/PyTorch_API/PTQ/EM_lstm_quant.py", line 21, in forward
    x = self.rnn( x , (self.h0, self.c0) )
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/nn/quantizable/modules/rnn.py", line 363, in forward
    x, (h, c) = layer(x, hxcx[idx])
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/nn/quantizable/modules/rnn.py", line 201, in forward
    result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/nn/quantizable/modules/rnn.py", line 150, in forward
    hidden = self.cell(xx, hidden)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/nn/quantizable/modules/rnn.py", line 71, in forward
    hgates = self.hgates(hx)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/ao/nn/quantized/modules/linear.py", line 168, in forward
    return torch.ops.quantized.linear(
  File "/home/ahmed/.local/lib/python3.10/site-packages/torch/_ops.py", line 502, in __call__
    return self._op(*args, **kwargs or {})
NotImplementedError: Could not run 'quantized::linear' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::linear' is only available for these backends: [QuantizedCPU, QuantizedCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].

QuantizedCPU: registered at ../aten/src/ATen/native/quantized/cpu/qlinear.cpp:990 [kernel]
QuantizedCUDA: registered at ../aten/src/ATen/native/quantized/cudnn/Linear.cpp:360 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:491 [backend fallback]
Functionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:280 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:63 [backend fallback]
AutogradOther: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:30 [backend fallback]
AutogradCPU: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:34 [backend fallback]
AutogradCUDA: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:42 [backend fallback]
AutogradXLA: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:46 [backend fallback]
AutogradMPS: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:54 [backend fallback]
AutogradXPU: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:38 [backend fallback]
AutogradHPU: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:67 [backend fallback]
AutogradLazy: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:50 [backend fallback]
AutogradMeta: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:58 [backend fallback]
Tracer: registered at ../torch/csrc/autograd/TraceTypeManual.cpp:294 [backend fallback]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:487 [backend fallback]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:354 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:815 [backend fallback]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1073 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:210 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:152 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:487 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]

What does this mean?

Could not run 'quantized::linear' with arguments from the 'CPU' backend.

Is this coming from the ONNX export, or should I just write custom quantization configs (or even a custom LSTM layer)?

Thank you,

This means you are feeding a CPU (non-quantized) tensor to a quantized operator. Maybe you can try quantizing self.h0 and self.c0 as well?
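Something like this is what I mean, as a rough sketch (the separate stubs and the per-tensor dequant are my additions, I have not run this):

import torch

class M(torch.nn.Module):
    def __init__(self, h0, c0):
        super().__init__()
        self.h0 = h0
        self.c0 = c0
        # one QuantStub per input so each tensor gets its own observer
        self.quant_x = torch.ao.quantization.QuantStub()
        self.quant_h = torch.ao.quantization.QuantStub()
        self.quant_c = torch.ao.quantization.QuantStub()
        self.rnn = torch.nn.LSTM(10, 20, 2)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant_x(x)
        h = self.quant_h(self.h0)
        c = self.quant_c(self.c0)
        out, (hn, cn) = self.rnn(x, (h, c))
        # dequantize the individual tensors, not the (out, (h, c)) tuple,
        # since DeQuantStub only handles tensors after convert
        return self.dequant(out), (self.dequant(hn), self.dequant(cn))

That should get you past the quantized::linear error for the hidden states; whether the ONNX exporter then accepts the quantized LSTM is a separate question.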

BTW, in case you did not find it, here is the doc for custom module quantization: Quantization — PyTorch 2.0 documentation
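
And if the static path keeps fighting you, dynamic quantization is the more common route for LSTMs, since activations and hidden states stay fp32 and only the weights are quantized. Just my suggestion, not something from the doc above; a minimal sketch:

import torch

class FloatLSTM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(10, 20, 2)

    def forward(self, x, hidden):
        return self.rnn(x, hidden)

model_fp32 = FloatLSTM().eval()

# swap nn.LSTM for its dynamically quantized counterpart;
# inputs and hidden states remain regular fp32 CPU tensors
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32, {torch.nn.LSTM}, dtype=torch.qint8
)

x = torch.rand(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
out, (hn, cn) = model_int8(x, (h0, c0))

Whether that variant exports to ONNX any better I have not checked.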