[pt2e][quant] Quantization of operators with multiple outputs (RNN, LSTM)

Environment:
torch==2.6.0
executorch==0.5.0

Description:
I’m working on quantization of an LSTM module in ExecuTorch using the PyTorch 2 Export (PT2E) Quantization flow, doing Post Training Quantization. The nn.LSTM module is exported to the aten.lstm.input ATen operation, for which I have a quantizer. In torch/ao/quantization/quantize_pt2e.py, prepare_pt2e() annotates the model. The problem shows up in _maybe_insert_input_and_output_observers_for_node() at torch/ao/quantization/pt2e/prepare.py#L501-L503: that code only handles operation nodes with a single output, but LSTM has 3 outputs, and node.meta["val"] is a list of FakeTensor there. To quantize recurrent layers, support for multiple outputs is needed.
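
To illustrate, here is a minimal sketch (using the same export API and shapes as the repro below; the exact contents of node.meta["val"] are what I observe, not documented behavior):

import torch

lstm = torch.nn.LSTM(64, hidden_size=16)
example_input = (torch.ones(1, 64),)
exported = torch._export.capture_pre_autograd_graph(lstm, example_input)
for node in exported.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.lstm.input:
        # meta["val"] is a list of 3 FakeTensors: (output, h_n, c_n)
        print(type(node.meta["val"]), len(node.meta["val"]))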

I’ve read the Quantization docs and noticed that they describe support for Eager Mode Quantization and FX Graph Mode Quantization, with links to examples, but the issue here is specifically how to add observers for multiple outputs.

Is quantization of operation nodes with multiple outputs possible in the current PT2E workflow? Are there further blockers, or does the single-output limitation exist simply because the work hasn’t reached that point yet?

@jerryzh168, could you please take a look at this?

Code to reproduce:
(runnable from the executorch repo)

import torch

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional, Type

from torch import fx
from torch import nn
from torch._ops import OpOverload
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer, SharedQuantizationSpec
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)

from executorch.backends.cadence.aot.quantizer.utils import (
    find_sequential_partitions_aten,
    is_annotated,
    no_outside_users,
)
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams


@dataclass
class PartitionAnchors:
    """
    All fields are lists of (node, input_node), where node is from
    the given partition and input_node is an input to the partition.

    The quantizer uses inputs, weights and biases for quantization annotation. The others
    field contains tensor inputs that aren't quantized, and the literals field is used
    for other types of input values as well as for handling default parameters.
    """

    inputs: list[tuple[fx.Node, fx.Node]] = field(default_factory=list)
    weights: list[tuple[fx.Node, fx.Node]] = field(default_factory=list)
    biases: list[tuple[fx.Node, fx.Node] | tuple[fx.Node, fx.Node, DerivedQuantizationSpec]] = field(default_factory=list)
    others: list[tuple[fx.Node, fx.Node]] = field(default_factory=list)
    literals: list[tuple[fx.Node, fx.Node]] = field(default_factory=list)
    output: list[tuple[fx.Node] | tuple[fx.Node, SharedQuantizationSpec]] = field(default_factory=list)


class CustomQuantizationPattern(ABC):
    @abstractmethod
    def partition_types(self) -> list[OpOverload]:
        """
        List of types to be passed to find_sequential_partitions_aten.
        """
        pass

    @abstractmethod
    def get_anchors(
            self, gm: torch.fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> Optional[PartitionAnchors]:
        pass


class LstmInputPattern(CustomQuantizationPattern):
    """
    Quantization pattern for aten.lstm.input quantization. Accepts 3 input nodes.

    Basic quantization for all inputs and outputs.
    """

    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.lstm.input]

    def get_anchors(
            self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        lstm_node = fused_partition[0].nodes[-1]
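        # aten.lstm.input args: (input, hx, params, has_biases, ...), where
        # hx = [h_0, c_0] and params is a flat list of weight/bias tensors.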
        hidden_state_0, cell_state_0 = lstm_node.args[1][0], lstm_node.args[1][1]
        inputs = [(lstm_node, lstm_node.args[0]), (lstm_node, hidden_state_0), (lstm_node, cell_state_0)]
        weights = [(lstm_node, node) for node in lstm_node.args[2] if "weight" in node.target]
        weights_edges = [(x, y) for y, x in weights]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (lstm_node.args[0], lstm_node),
                (hidden_state_0, lstm_node),
                (cell_state_0, lstm_node),
                *weights_edges,
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )
        biases = [(lstm_node, node, bias_qspec) for node in lstm_node.args[2] if "bias" in node.target]

        return PartitionAnchors(
            inputs=inputs,
            weights=weights,
            biases=biases,
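            # One output anchor per LSTM output (output, h_n, c_n); this is where
            # the single-output assumption in prepare breaks down.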
            output=[(lstm_node,), (lstm_node,), (lstm_node,)],
        )


class CustomAtenQuantizer(Quantizer):
    def __init__(
            self, pattern: CustomQuantizationPattern, quantization_config: QuantizationConfig
    ) -> None:
        super().__init__()
        self.pattern = pattern
        self.quantization_config = quantization_config

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        fused_partitions = find_sequential_partitions_aten(
            model,
            self.pattern.partition_types(),
        )

        input_act_qspec = self.quantization_config.input_activation
        weight_qspec = self.quantization_config.weight
        bias_qspec = self.quantization_config.bias
        output_act_qspec = self.quantization_config.output_activation

        for fused_partition in fused_partitions:
            if not no_outside_users(fused_partition):
                continue

            anchors = self.pattern.get_anchors(model, fused_partition)
            if not anchors:
                continue
            if is_annotated(
                    [
                        x[0]
                        for x in anchors.inputs
                                 + anchors.weights
                                 + anchors.biases
                                 + anchors.output
                    ]
            ):
                continue

            for output, *custom_spec in anchors.output:
                # pyre-ignore[16]: no attribute
                output.meta["quantization_annotation"] = QuantizationAnnotation(
                    # pyre-ignore[6]: incompatible parameter type
                    output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
                    _annotated=True,
                )

            def annotate_inputs(
                    inputs: list[tuple[fx.Node, fx.Node]] | list[tuple[fx.Node, fx.Node, DerivedQuantizationSpec]],
                    spec: Optional[QuantizationSpec],
            ) -> None:
                for node, input_node, *custom_spec in inputs:
                    # pyre-ignore[16]: no attribute
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    # pyre-ignore[16]: no attribute
                    annotation.input_qspec_map[input_node] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    # pyre-ignore[16]: no attribute
                    node.meta["quantization_annotation"] = annotation

            annotate_inputs(anchors.inputs, input_act_qspec)
            annotate_inputs(anchors.weights, weight_qspec)
            # pyre-ignore[6]: incompatible parameter type
            annotate_inputs(anchors.biases, bias_qspec)
        return model

    def validate(self, model: fx.GraphModule) -> None:
        pass


act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2 ** -12),
)

wgt_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
    ch_axis=0
)

class CustomComposableQuantizer(ComposableQuantizer):
    def __init__(self):
        static_qconfig = QuantizationConfig(
            act_qspec,
            act_qspec,
            wgt_qspec,
            None,
        )
        super().__init__([CustomAtenQuantizer(LstmInputPattern(), static_qconfig),])

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for quantizer in self.quantizers:
            quantizer.annotate(model)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        return super().validate(model)
    
    
def test_lstm_model():
    input_shape = (1, 64)
    model = nn.LSTM(64, hidden_size=16, bidirectional=False, bias=True, num_layers=1)
    random_tensors = (torch.randn(input_shape),)
    calibration_inputs = [random_tensors, random_tensors]
    example_input = (torch.ones(input_shape),)

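    # capture_pre_autograd_graph is deprecated in newer torch releases;
    # torch.export.export_for_training is the suggested replacement (see the reply below).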
    exir_program_aten = torch._export.capture_pre_autograd_graph(model, example_input)

    quantizer = CustomComposableQuantizer()

    m = prepare_pt2e(exir_program_aten, quantizer)
    for i, data in enumerate(calibration_inputs):
        m(*data)
    m = convert_pt2e(m)

    print(m) # no quantize nodes after lstm.input

For static quant of the LSTM op, we’d need to break it down first; it has many internal states and internal linear ops, so how are you going to quantize it as a whole?

What we did before to quantize the SDPA op (also a large op) is the following:

import torch

from torch._export import capture_pre_autograd_graph
from torch.export import export
from torch._decomp import decomposition_table
from torch.ao.quantization.pt2e.export_utils import _WrapperModule

def fn(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2).contiguous(),
        k.transpose(1, 2),
        v.transpose(1, 2),
        scale=0.125,
    )[:2]

example_inputs = (
    torch.randn(4, 2, 4, 2, dtype=torch.float),
    torch.randn(4, 2, 4, 2, dtype=torch.float),
    torch.randn(4, 2, 4, 2, dtype=torch.float)
)

# pattern
pattern = capture_pre_autograd_graph(_WrapperModule(fn), example_inputs)
# pattern = export(_WrapperModule(fn), example_inputs)
pattern.graph.eliminate_dead_code()
# print("pattern:", pattern)


# replacement
# torch._dynamo.reset()
decomp_table = {
    torch.ops.aten._scaled_dot_product_flash_attention.default: decomposition_table[torch.ops.aten._scaled_dot_product_flash_attention.default],
}
# replacement = torch._dynamo.export(
#     fn,
#     constraints=None,
#     assume_static_by_default=True,
#     tracing_mode="symbolic",
#     decomposition_table=decomp_table,
#     pre_dispatch=False,
#     aten_graph=True,
# )(
#     *example_inputs,
# )[0]
# replacement.graph.eliminate_dead_code()
# print("post dispatch tracing:", m)
replacement = (
    torch.export.export(_WrapperModule(fn), example_inputs)
    .run_decompositions(decomp_table)
    .module()
)

from torch.fx import subgraph_rewriter
class M(torch.nn.Module):
    def forward(self, q, k, v):
        return fn(q, k, v)

m = M().eval()
model = capture_pre_autograd_graph(m, example_inputs)
print("before replacement:", model)
subgraph_rewriter.replace_pattern(model, pattern, replacement)
print("after replacement:", model)

Some of the APIs have been updated since then, e.g. capture_pre_autograd_graph has been replaced by torch.export.export_for_training, but the rest is similar.
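
For reference, a sketch of the same pattern capture with the updated API (assuming the fn and example_inputs defined above):

from torch.export import export_for_training

# export_for_training returns an ExportedProgram; .module() gives the GraphModule
# that subgraph_rewriter can consume.
pattern = export_for_training(_WrapperModule(fn), example_inputs).module()
pattern.graph.eliminate_dead_code()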

Basically, you’ll break the LSTM op into multiple aten.linear ops, and then you can use the normal annotation flow to quantize them.

Thank you for the response. The idea is to quantize the LSTM as a whole and then convert it to our IR, taking care of the internal states ourselves (e.g. creating the weight tensors). The approach of decomposing the LSTM and quantizing the separate ops was considered, but unfortunately that graph also contains an op with multiple outputs. I used the reimplementation of LSTM from torch.ao.nn.quantizable.LSTM, and its graph has a chunk op:

%chunk : [num_users=4] = call_function[target=torch.ops.aten.chunk.default](args = (%add, 4, 1), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%chunk, 0), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%chunk, 1), kwargs = {})
%getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%chunk, 2), kwargs = {})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%chunk, 3), kwargs = {})
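
For the chunk case I'm wondering whether the annotation could be attached to the single-output operator.getitem users instead of the chunk node itself. A rough sketch of that idea (my assumption, not something I've verified against the PT2E prepare logic):

import operator

from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationAnnotation

# annotate_chunk_outputs is an illustrative helper (sketch), not part of any quantizer.
def annotate_chunk_outputs(chunk_node, output_act_qspec):
    # Each getitem user is single-output, so the existing observer insertion
    # code should be able to handle it.
    for user in chunk_node.users:
        if user.op == "call_function" and user.target == operator.getitem:
            user.meta["quantization_annotation"] = QuantizationAnnotation(
                output_qspec=output_act_qspec,
                _annotated=True,
            )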

The code you posted does not run for me; I’m getting a KeyError on the decomposition_table lookup:

torch.ops.aten._scaled_dot_product_flash_attention.default: decomposition_table[torch.ops.aten._scaled_dot_product_flash_attention.default],
KeyError: <OpOverload(op='aten._scaled_dot_product_flash_attention', overload='default')>

I tried to implement changes in PyTorch that would enable multiple-output quantization, and so far I have gotten past torch.ao.quantization.quantize_pt2e.prepare_pt2e(). Basically, I changed torch/ao/quantization/pt2e/prepare.py and torch/fx/node.py to allow tuples and to add observers for every output. It works until the step where calibration data is passed through the model, because the observers are not ready to accept collections of tensors (one per output). What do you think about implementing support for this in the observers? Some changes would probably also be needed in torch/ao/quantization/quantize_pt2e.py.
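
To make the observer idea concrete, here is a rough sketch of what I have in mind; MultiOutputObserver is a hypothetical class, not an existing PyTorch observer:

import torch

from torch.ao.quantization.observer import HistogramObserver

class MultiOutputObserver(torch.nn.Module):
    """Hypothetical sketch: holds one sub-observer per output and observes
    each element of a tuple/list of tensors during calibration."""

    def __init__(self, num_outputs: int, observer_ctr=HistogramObserver):
        super().__init__()
        self.observers = torch.nn.ModuleList(observer_ctr() for _ in range(num_outputs))

    def forward(self, outputs):
        # Pass every output tensor through its own observer; values are returned unchanged.
        return type(outputs)(obs(x) for obs, x in zip(self.observers, outputs))

    def calculate_qparams(self):
        # One (scale, zero_point) pair per output.
        return [obs.calculate_qparams() for obs in self.observers]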