Hi folks,
I was trying to use the updated quantization APIs to quantize a mock model for ExecuTorch consumption. The ResNet tutorial linked from the relevant wiki failed with PyTorch 2.3; I have posted a comment there noting the failure.
Nevertheless, I tried to proceed with this tutorial on a simple model (just two linear ops). Here's the code:
import torch
import torch.nn as nn
torch.manual_seed(0)
class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleModel, self).__init__()
        self.linear1 = nn.Linear(input_size, output_size)
        self.linear2 = nn.Linear(output_size, input_size)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

class FullModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(FullModel, self).__init__()
        self.simple_model = SimpleModel(input_size, output_size)

    def forward(self, x):
        output = self.simple_model(x)
        return output

def get_simple_model():
    input_size = 364
    output_size = 728
    model = FullModel(input_size, output_size)
    output = model(torch.randn(1, 3, 364, 364))
    print(model)
    return model

m = get_simple_model()
example_inputs = (torch.randn(1, 3, 364, 364),)
# quantizer code adapted from https://pytorch.org/tutorials/prototype/pt2e_quantizer.html
import copy
import itertools
import operator
from typing import Callable, Dict, List, Optional, Set, Any
import torch._dynamo as torchdynamo
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    get_input_act_qspec,
    get_output_act_qspec,
    get_bias_qspec,
    get_weight_qspec,
    OperatorConfig,
    QuantizationConfig,
    QuantizationSpec,
)
from torch.ao.quantization.observer import (
    HistogramObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
from torch.ao.quantization.quantizer.quantizer import (
    Quantizer,
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import Node
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
# NOTE: _ObserverOrFakeQuantizeConstructor is only used for type annotations
# below; assuming it lives in torch.ao.quantization.qconfig (as in the
# XNNPACKQuantizer source). Local annotations are not evaluated, so the code
# also runs without this import.
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

def _mark_nodes_as_annotated(nodes: List[Node]):
    for node in nodes:
        if node is not None:
            if "quantization_annotation" not in node.meta:
                node.meta["quantization_annotation"] = QuantizationAnnotation()
            node.meta["quantization_annotation"]._annotated = True

def _is_annotated(nodes: List[Node]):
    annotated = False
    for node in nodes:
        annotated = annotated or (
            "quantization_annotation" in node.meta
            and node.meta["quantization_annotation"]._annotated
        )
    return annotated

class AZQuantizer(Quantizer):
    def __init__(self):
        super().__init__()
        self.global_config: QuantizationConfig = None  # type: ignore[assignment]
        self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}

    def set_global(self, quantization_config: QuantizationConfig):
        """Set the global QuantizationConfig used for the backend.
        QuantizationConfig is defined in torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py.
        """
        self.global_config = quantization_config
        return self

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Annotate nodes in the graph with observer or fake-quant constructors
        to convey the desired way of quantization.
        """
        global_config = self.global_config
        self.annotate_symmetric_config(model, global_config)
        return model

    def annotate_symmetric_config(
        self, model: torch.fx.GraphModule, config: QuantizationConfig
    ) -> torch.fx.GraphModule:
        self._annotate_linear(model, config)
        return model

    def _annotate_linear(
        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
    ) -> None:
        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
        )
        act_qspec = get_input_act_qspec(quantization_config)
        weight_qspec = get_weight_qspec(quantization_config)
        bias_qspec = get_bias_qspec(quantization_config)
        for module_or_fn_type, partitions in module_partitions.items():
            if module_or_fn_type == torch.nn.Linear:
                for p in partitions:
                    act_node = p.input_nodes[0]
                    output_node = p.output_nodes[0]
                    weight_node = None
                    bias_node = None
                    for node in p.params:
                        weight_or_bias = getattr(gm, node.target)  # type: ignore[arg-type]
                        if weight_or_bias.ndim == 2:  # type: ignore[attr-defined]
                            weight_node = node
                        if weight_or_bias.ndim == 1:  # type: ignore[attr-defined]
                            bias_node = node
                    if weight_node is None:
                        raise ValueError("No weight found in Linear pattern")
                    # find the use of the act node within the matched pattern
                    act_use_node = None
                    for node in p.nodes:
                        if node in act_node.users:  # type: ignore[union-attr]
                            act_use_node = node
                            break
                    if act_use_node is None:
                        raise ValueError(
                            "Could not find a user of the act node within the matched pattern."
                        )
                    if _is_annotated([act_use_node]) is False:  # type: ignore[list-item]
                        _annotate_input_qspec_map(
                            act_use_node,
                            act_node,
                            act_qspec,
                        )
                    if bias_node and _is_annotated([bias_node]) is False:
                        _annotate_output_qspec(bias_node, bias_qspec)
                    if _is_annotated([weight_node]) is False:  # type: ignore[list-item]
                        _annotate_output_qspec(weight_node, weight_qspec)
                    if _is_annotated([output_node]) is False:
                        _annotate_output_qspec(output_node, act_qspec)
                    nodes_to_mark_annotated = list(p.nodes)
                    _mark_nodes_as_annotated(nodes_to_mark_annotated)

    def validate(self, model: torch.fx.GraphModule) -> None:
        """Validate whether the annotated graph is supported by the backend."""
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []

def get_symmetric_quantization_config():
    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = HistogramObserver
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver
    extra_args: Dict[str, Any] = {"eps": 2**-12}
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
    )
    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.float,
        observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
    )
    quantization_config = QuantizationConfig(
        act_quantization_spec,
        act_quantization_spec,
        weight_quantization_spec,
        bias_quantization_spec,
    )
    return quantization_config

m_copy = copy.deepcopy(m)
m, guards = torchdynamo.export(
    m,
    *copy.deepcopy(example_inputs),
    aten_graph=True,
)
quantizer = AZQuantizer()
operator_config = get_symmetric_quantization_config()
quantizer.set_global(operator_config)
m = prepare_pt2e(m, quantizer)
after_prepare_result = m(*example_inputs)
m = convert_pt2e(m)
print("converted module is: {}".format(m), flush=True)
print("original module is: {}".format(m_copy), flush=True)
I basically ported the Quantizer code from the tutorial and only annotated the relevant op (Linear). I am seeing an intermittent bug when running this code:
Traceback (most recent call last):
File "/me/anaconda3/lib/python3.9/site-packages/torch/fx/passes/infra/pass_manager.py", line 270, in __call__
res = fn(module)
File "/me/anaconda3/lib/python3.9/site-packages/torch/fx/passes/infra/pass_base.py", line 40, in __call__
res = self.call(graph_module)
File "/me/anaconda3/lib/python3.9/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py", line 197, in call
_port_metadata_for_output_quant_nodes(node, output_qspec)
File "/me/anaconda3/lib/python3.9/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py", line 124, in _port_metadata_for_output_quant_nodes
raise InternalError(f"Expecting {node} to have single user")
torch._export.error.InternalError: Expecting sym_size_int to have single user
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "simple_model.py", line 223, in <module>
m = convert_pt2e(m)
File "/me/anaconda3/lib/python3.9/site-packages/torch/ao/quantization/quantize_pt2e.py", line 239, in convert_pt2e
model = pm(model).graph_module
File "/me/anaconda3/lib/python3.9/site-packages/torch/fx/passes/infra/pass_manager.py", line 296, in __call__
raise Exception(msg) from e
Exception: An error occurred when running the 'PortNodeMetaForQDQ' pass after the following passes: []
Sometimes I am able to successfully convert the model (using prepare_pt2e and convert_pt2e), producing an ATen graph with quantized operators, but other times I hit this error.
Any ideas on what I am doing wrong here?
Also: does ExecuTorch require this PT2E form of quantization in its compilation flow, or can we use the standard FakeQuant quantization?
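(By "standard FakeQuant quantization" I mean roughly the FX graph mode QAT flow; a rough sketch of what I have in mind, not tied to any particular backend config:)

from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

# insert fake-quant modules according to a default QAT qconfig mapping
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
prepared = prepare_qat_fx(m_copy.train(), qconfig_mapping, example_inputs)
# ... run QAT / calibration here ...
quantized = convert_fx(prepared.eval())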