As mentioned in the Quantization — PyTorch 2.6 documentation, PyTorch supports 3 quantization modes: Eager Mode, FX Graph Mode, and PyTorch 2 Export (PT2E).
And in Eager and FX mode, the demos show a pipeline going from PTQ to QAT, which is useful for getting better precision.
import torch
from torch.ao.quantization import (
get_default_qconfig_mapping,
get_default_qat_qconfig_mapping,
QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy
model_fp = UserModel()
#
# post training dynamic/weight_only quantization
#
# we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# a tuple of one or more example inputs is needed to trace the model
example_inputs = (input_fp32,)  # note the trailing comma: this must be a tuple
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
#
# post training static quantization
#
model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
#
# quantization aware training for static quantization
#
model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)
But in PyTorch 2 Export, the APIs change to prepare_pt2e and prepare_qat_pt2e. I think they are different, so how do I get a pipeline to do this?
What I imagine is something like the diagram below (a rough code sketch follows it):
float_model (Python)          Example Input
          \                        /
           \                      /
 --------------------------------------------------------
 |                        export                         |
 --------------------------------------------------------
                          |
     FX Graph in ATen           Backend-Specific Quantizer
                          |           /
 --------------------------------------------------------
 |                     prepare_pt2e                      |
 --------------------------------------------------------
                          |
                      Calibrate
                          |
 --------------------------------------------------------
 |                 convert back to FX ?                  |
 --------------------------------------------------------
                          |
                      FX Graph ?
                          |
 --------------------------------------------------------
 |                   prepare_qat_pt2e                    |
 --------------------------------------------------------
                          |
                Training loop (QAT)
                          |
 --------------------------------------------------------
 |                     convert_pt2e                      |
 --------------------------------------------------------
                          |
                   Quantized Model
                          |
 --------------------------------------------------------
 |                       Lowering                        |
 --------------------------------------------------------
                          |
          ExecuTorch, Inductor or <Other Backends>
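
To make the diagram concrete, here is a rough sketch of what I think the two PT2E flows look like, pieced together from the docs. UserModel, input_fp32 and the choice of XNNPACKQuantizer are just placeholders/assumptions (any backend-specific quantizer could go there), and the PTQ → QAT chaining step is exactly the part I am unsure about:

import torch
from torch.ao.quantization.quantize_pt2e import (
    prepare_pt2e,
    prepare_qat_pt2e,
    convert_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model_fp = UserModel()            # placeholder user model
example_inputs = (input_fp32,)    # placeholder example input (tuple)

#
# PT2E post training static quantization
#
# export to an ATen FX graph (PyTorch 2.5+; older releases used capture_pre_autograd_graph)
exported_model = torch.export.export_for_training(model_fp, example_inputs).module()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
model_prepared = prepare_pt2e(exported_model, quantizer)
# calibrate (not shown): run representative data through model_prepared
model_quantized = convert_pt2e(model_prepared)

#
# PT2E quantization aware training
#
exported_model = torch.export.export_for_training(model_fp, example_inputs).module()
qat_quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))
model_prepared = prepare_qat_pt2e(exported_model, qat_quantizer)
# training loop (not shown)
model_quantized = convert_pt2e(model_prepared)

# what I don't know: whether the PTQ-prepared/converted graph can be fed back into
# prepare_qat_pt2e (the "convert back to FX ?" step in my diagram), or whether PTQ
# and QAT are meant to stay two separate pipelines in PT2E.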
thanks~