Quantizer Backend for Linear Op intermittent failures (ExecuTorch)

Hi folks,

I was trying to use the updated quantization APIs to quantize a mock model for ExecuTorch consumption. The tutorial for Resnet linked in the relevant wiki failed with PyTorch 2.3. I have posted a comment there noting the failure.
Nevertheless I tried to proceed with this tutorial on a simple model (just two linear ops). Here’s the code:

import torch
import torch.nn as nn
torch.manual_seed(0)


class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleModel, self).__init__()
        self.linear1 = nn.Linear(input_size, output_size)
        self.linear2 = nn.Linear(output_size, input_size)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class FullModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(FullModel, self).__init__()
        self.simple_model = SimpleModel(input_size, output_size)

    def forward(self, x):
        output = self.simple_model(x)
        return output

def get_simple_model():
    input_size = 364
    output_size = 728
    model = FullModel(input_size, output_size)
    output = model(torch.randn(1, 3, 364, 364))
    print(model)
    return model


m = get_simple_model()
example_inputs = (torch.randn(1, 3, 364, 364),)

# quantizer code adapted from https://pytorch.org/tutorials/prototype/pt2e_quantizer.html
import copy
import itertools
import operator
from typing import Callable, Dict, List, Optional, Set, Any
import torch._dynamo as torchdynamo
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import(
    get_input_act_qspec,
    get_output_act_qspec,
    get_bias_qspec,
    get_weight_qspec,
    OperatorConfig,
    QuantizationConfig,
    QuantizationSpec)
from torch.ao.quantization.observer import (
    HistogramObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
from torch.ao.quantization.quantizer.quantizer import (
    Quantizer,
    QuantizationAnnotation,
    SharedQuantizationSpec
)
from torch.fx import Node
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec)

def _mark_nodes_as_annotated(nodes: List[Node]):
    for node in nodes:
        if node is not None:
            if "quantization_annotation" not in node.meta:
                node.meta["quantization_annotation"] = QuantizationAnnotation()
            node.meta["quantization_annotation"]._annotated = True

def _is_annotated(nodes: List[Node]):
    annotated = False
    for node in nodes:
        annotated = annotated or (
            "quantization_annotation" in node.meta
            and node.meta["quantization_annotation"]._annotated
        )
    return annotated

class AZQuantizer(Quantizer):

    def __init__(self):
        super().__init__()
        self.global_config: QuantizationConfig = None  # type: ignore[assignment]
        self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}

    def set_global(self, quantization_config: QuantizationConfig):
        """set global QuantizationConfig used for the backend.
        QuantizationConfig is defined in torch/ao/quantization/_pt2e/quantizer/quantizer.py.
        """
        self.global_config = quantization_config
        return self

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """annotate nodes in the graph with observer or fake quant constructors
        to convey the desired way of quantization.
        """
        global_config = self.global_config
        self.annotate_symmetric_config(model, global_config)

        return model

    def annotate_symmetric_config(
        self, model: torch.fx.GraphModule, config: QuantizationConfig
    ) -> torch.fx.GraphModule:
        self._annotate_linear(model, config)
        return model

    def _annotate_linear(
        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
    ) -> None:
        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
        )
        act_qspec = get_input_act_qspec(quantization_config)
        weight_qspec = get_weight_qspec(quantization_config)
        bias_qspec = get_bias_qspec(quantization_config)
        for module_or_fn_type, partitions in module_partitions.items():
            if module_or_fn_type == torch.nn.Linear:
                for p in partitions:
                    act_node = p.input_nodes[0]
                    output_node = p.output_nodes[0]
                    weight_node = None
                    bias_node = None
                    for node in p.params:
                        weight_or_bias = getattr(gm, node.target)  # type: ignore[arg-type]
                        if weight_or_bias.ndim == 2:  # type: ignore[attr-defined]
                            weight_node = node
                        if weight_or_bias.ndim == 1:  # type: ignore[attr-defined]
                            bias_node = node
                    if weight_node is None:
                        raise ValueError("No weight found in Linear pattern")
                    # find use of act node within the matched pattern
                    act_use_node = None
                    for node in p.nodes:
                        if node in act_node.users:  # type: ignore[union-attr]
                            act_use_node = node
                            break
                    if act_use_node is None:
                        raise ValueError(
                            "Could not find an user of act node within matched pattern."
                        )
                    if _is_annotated([act_use_node]) is False:  # type: ignore[list-item]
                        _annotate_input_qspec_map(
                            act_use_node,
                            act_node,
                            act_qspec,
                        )
                    if bias_node and _is_annotated([bias_node]) is False:
                        _annotate_output_qspec(bias_node, bias_qspec)
                    if _is_annotated([weight_node]) is False:  # type: ignore[list-item]
                        _annotate_output_qspec(weight_node, weight_qspec)
                    if _is_annotated([output_node]) is False:
                        _annotate_output_qspec(output_node, act_qspec)
                    nodes_to_mark_annotated = list(p.nodes)
                    _mark_nodes_as_annotated(nodes_to_mark_annotated)

    def validate(self, model: torch.fx.GraphModule) -> None:
        """validate if the annotated graph is supported by the backend"""
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []

def get_symmetric_quantization_config():
    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
        HistogramObserver
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
    )

    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver
    extra_args: Dict[str, Any] = {"eps": 2**-12}
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
    )

    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.float,
        observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
    )
    quantization_config = QuantizationConfig(
        act_quantization_spec,
        act_quantization_spec,
        weight_quantization_spec,
        bias_quantization_spec,
    )
    return quantization_config

m_copy = copy.deepcopy(m)
m, guards = torchdynamo.export(
    m,
    *copy.deepcopy(example_inputs),
    aten_graph=True,
)
quantizer = AZQuantizer()
operator_config = get_symmetric_quantization_config()
quantizer.set_global(operator_config)
m = prepare_pt2e(m, quantizer)
after_prepare_result = m(*example_inputs)
m = convert_pt2e(m)
print("converted module is: {}".format(m), flush=True)
print("original module is: {}".format(m_copy), flush=True)

I basically ported the Quantizer code from the tutorial and only annotated the relevant op (Linear). I am seeing an intermittent bug when running this code:

Traceback (most recent call last):
  File "/me/anaconda3/lib/python3.9/site-packages/torch/fx/passes/infra/pass_manager.py", line 270, in __call__
    res = fn(module)
  File "/me/anaconda3/lib/python3.9/site-packages/torch/fx/passes/infra/pass_base.py", line 40, in __call__
    res = self.call(graph_module)
  File "/me/anaconda3/lib/python3.9/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py", line 197, in call
    _port_metadata_for_output_quant_nodes(node, output_qspec)
  File "/me/anaconda3/lib/python3.9/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py", line 124, in _port_metadata_for_output_quant_nodes
    raise InternalError(f"Expecting {node} to have single user")
torch._export.error.InternalError: Expecting sym_size_int to have single user

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "simple_model.py", line 223, in <module>
    m = convert_pt2e(m)
  File "/me/anaconda3/lib/python3.9/site-packages/torch/ao/quantization/quantize_pt2e.py", line 239, in convert_pt2e
    model = pm(model).graph_module
  File "/me/anaconda3/lib/python3.9/site-packages/torch/fx/passes/infra/pass_manager.py", line 296, in __call__
    raise Exception(msg) from e
Exception: An error occurred when running the 'PortNodeMetaForQDQ' pass after the following passes: []

I am able to successfully convert the model (using prepare_pt2e and convert_pt2e) sometimes, producing a Aten graph with quantized operators. But I do see this issue other times.

Any ideas on what I am doing wrong here?

Also: does ExecuTorch require this form of quantization in its compilation flow? Or can we do the standard FakeQuant quantization?

@jerryzh168 or @kimishpatel might be able to help answer this.

Can you replace the use of torchdynamo.export with capture_pre_autograd like below

m = torch._export.capture_pre_autograd_graph(
    m_copy,
    *copy.deepcopy(example_inputs),
)

Quantization API requires to operate on what is called “pre_dispatch” IR and not core aten IR which is what torchdynamo.export probably returns.

Thanks, this worked. Wondering if you can answer this question?
Also: does ExecuTorch require this form of quantization in its compilation flow? Or can we do the standard FakeQuant quantization?

What do you mean by Or can we do the standard FakeQuant quantization?

The FQs here:
https://pytorch.org/docs/stable/generated/torch.ao.quantization.fake_quantize.FakeQuantize.html

Do we need a custom Quantizer backend for an accelerator? Or can we ingest quantized models with FQs directly?