Environment:
torch==2.6.0
executorch==0.5.0
Description:
I’m working on quantization of the LSTM module in ExecuTorch using PyTorch 2 Export (PT2E) Quantization, doing Post Training Quantization. The nn.LSTM is exported to the lstm.input ATen operation, for which I have a quantizer. In torch/ao/quantization/quantize_pt2e.py, prepare_pt2e() annotates the model. The problem shows up in torch/ao/quantization/prepare.py#L501-L503, in _maybe_insert_input_and_output_observers_for_node(): this code only works for operation nodes with a single output, while LSTM has 3 outputs and node.meta["val"] is a list of FakeTensors there. To quantize recurrent layers, support for multiple outputs is needed.
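A minimal check of the exported graph (a sketch, using exir_program_aten from the repro below, before prepare_pt2e) shows the multi-output metadata:

for node in exir_program_aten.graph.nodes:
    if node.target == torch.ops.aten.lstm.input:
        # node.meta["val"] holds one FakeTensor per LSTM output (output, h_n, c_n)
        print(type(node.meta["val"]), len(node.meta["val"]))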
I’ve read the Quantization docs and noticed that Eager Mode Quantization and FX Graph Mode Quantization are supported, with links to examples. The sticking point here is adding observers for multiple outputs.
Is quantization of operation nodes with multiple outputs possible in the current PT2E workflow? Are there further obstacles, or does the single-output limitation exist only because the work simply hasn’t reached that point yet?
@jerryzh168, could you please take a look at this?
Code to reproduce:
(runnable from the executorch repo)
import torch
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional
from torch import fx
from torch import nn
from torch._ops import OpOverload
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer, SharedQuantizationSpec
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    find_sequential_partitions_aten,
    get_bias_qparams,
    is_annotated,
    no_outside_users,
)
@dataclass
class PartitionAnchors:
    """
    All fields are lists of (node, input_node), where node is from
    the given partition and input_node is an input to the partition.
    The quantizer uses inputs, weights and biases for quantization annotation.
    The others field contains tensor inputs that aren't quantized, and the
    literals field is used for other kinds of input values as well as for
    handling default parameters.
    """

    inputs: list[tuple[fx.Node, fx.Node]] = field(default_factory=list)
    weights: list[tuple[fx.Node, fx.Node]] = field(default_factory=list)
    biases: list[
        tuple[fx.Node, fx.Node] | tuple[fx.Node, fx.Node, DerivedQuantizationSpec]
    ] = field(default_factory=list)
    others: list[tuple[fx.Node, fx.Node]] = field(default_factory=list)
    literals: list[tuple[fx.Node, fx.Node]] = field(default_factory=list)
    output: list[tuple[fx.Node] | tuple[fx.Node, SharedQuantizationSpec]] = field(
        default_factory=list
    )
class CustomQuantizationPattern(ABC):
    @abstractmethod
    def partition_types(self) -> list[OpOverload]:
        """
        List of types to be passed to find_sequential_partitions_aten.
        """
        pass

    @abstractmethod
    def get_anchors(
        self, gm: torch.fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> Optional[PartitionAnchors]:
        pass
class LstmInputPattern(CustomQuantizationPattern):
    """
    Quantization pattern for LSTM input quantization (aten.lstm.input). Accepts 3 input nodes.
    Basic quantization for all inputs and outputs.
    """

    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.lstm.input]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        lstm_node = fused_partition[0].nodes[-1]
        # args[0] is the input tensor; args[1] holds the initial (h_0, c_0) states.
        hidden_state_0, cell_state_0 = lstm_node.args[1][0], lstm_node.args[1][1]
        inputs = [
            (lstm_node, lstm_node.args[0]),
            (lstm_node, hidden_state_0),
            (lstm_node, cell_state_0),
        ]
        # args[2] holds the flat params list (weights and biases).
        weights = [(lstm_node, node) for node in lstm_node.args[2] if "weight" in node.target]
        weights_edges = [(x, y) for y, x in weights]
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (lstm_node.args[0], lstm_node),
                (hidden_state_0, lstm_node),
                (cell_state_0, lstm_node),
                *weights_edges,
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )
        biases = [(lstm_node, node, bias_qspec) for node in lstm_node.args[2] if "bias" in node.target]
        return PartitionAnchors(
            inputs=inputs,
            weights=weights,
            biases=biases,
            # One output anchor per LSTM output (output, h_n, c_n).
            output=[(lstm_node,), (lstm_node,), (lstm_node,)],
        )
class CustomAtenQuantizer(Quantizer):
    def __init__(
        self, pattern: CustomQuantizationPattern, quantization_config: QuantizationConfig
    ) -> None:
        super().__init__()
        self.pattern = pattern
        self.quantization_config = quantization_config

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        fused_partitions = find_sequential_partitions_aten(
            model,
            self.pattern.partition_types(),
        )
        input_act_qspec = self.quantization_config.input_activation
        weight_qspec = self.quantization_config.weight
        bias_qspec = self.quantization_config.bias
        output_act_qspec = self.quantization_config.output_activation
        for fused_partition in fused_partitions:
            if not no_outside_users(fused_partition):
                continue
            anchors = self.pattern.get_anchors(model, fused_partition)
            if not anchors:
                continue
            if is_annotated(
                [
                    x[0]
                    for x in anchors.inputs
                    + anchors.weights
                    + anchors.biases
                    + anchors.output
                ]
            ):
                continue

            for output, *custom_spec in anchors.output:
                # pyre-ignore[16]: no attribute
                output.meta["quantization_annotation"] = QuantizationAnnotation(
                    # pyre-ignore[6]: incompatible parameter type
                    output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
                    _annotated=True,
                )

            def annotate_inputs(
                inputs: list[tuple[fx.Node, fx.Node]]
                | list[tuple[fx.Node, fx.Node, DerivedQuantizationSpec]],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, input_node, *custom_spec in inputs:
                    # pyre-ignore[16]: no attribute
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    # pyre-ignore[16]: no attribute
                    annotation.input_qspec_map[input_node] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    # pyre-ignore[16]: no attribute
                    node.meta["quantization_annotation"] = annotation

            annotate_inputs(anchors.inputs, input_act_qspec)
            annotate_inputs(anchors.weights, weight_qspec)
            # pyre-ignore[6]: incompatible parameter type
            annotate_inputs(anchors.biases, bias_qspec)
        return model

    def validate(self, model: fx.GraphModule) -> None:
        pass
act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)
wgt_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
    ch_axis=0,
)
class CustomComposableQuantizer(ComposableQuantizer):
    def __init__(self):
        static_qconfig = QuantizationConfig(
            act_qspec,
            act_qspec,
            wgt_qspec,
            None,
        )
        super().__init__([CustomAtenQuantizer(LstmInputPattern(), static_qconfig)])

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for quantizer in self.quantizers:
            quantizer.annotate(model)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        return super().validate(model)
def test_lstm_model():
    input_shape = (1, 64)
    model = nn.LSTM(64, hidden_size=16, bidirectional=False, bias=True, num_layers=1)
    random_tensors = (torch.randn(input_shape),)
    calibration_inputs = [random_tensors, random_tensors]
    example_input = (torch.ones(input_shape),)

    exir_program_aten = torch._export.capture_pre_autograd_graph(model, example_input)
    quantizer = CustomComposableQuantizer()
    m = prepare_pt2e(exir_program_aten, quantizer)
    for i, data in enumerate(calibration_inputs):
        m(*data)
    m = convert_pt2e(m)
    print(m)  # no quantize nodes after lstm.input
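For what it's worth, appending this at the end of test_lstm_model confirms the symptom (a sketch; in the exported graph each LSTM output is consumed through an operator.getitem node, so any quantize ops would appear among the getitem users):

    import operator
    for node in m.graph.nodes:
        if node.target == torch.ops.aten.lstm.input:
            for getitem in (u for u in node.users if u.target == operator.getitem):
                # one would expect quantize/dequantize users per output here; none appear
                print(getitem, [u.target for u in getitem.users])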