Here’s a recipe from the PyTorch blog post “Practical Quantization in PyTorch”:
```python
# QAT follows the same steps as PTQ, with the exception of the training loop before you actually convert the model to its quantized version
import torch
from torch import nn

backend = "fbgemm"  # running on an x86 CPU. Use "qnnpack" if running on ARM.

m = nn.Sequential(
    nn.Conv2d(2, 64, 8),
    nn.ReLU(),
    nn.Conv2d(64, 128, 8),
    nn.ReLU(),
)

"""Fuse"""
torch.quantization.fuse_modules(m, ['0', '1'], inplace=True)  # fuse first Conv-ReLU pair
torch.quantization.fuse_modules(m, ['2', '3'], inplace=True)  # fuse second Conv-ReLU pair

"""Insert stubs"""
m = nn.Sequential(torch.quantization.QuantStub(),
                  *m,
                  torch.quantization.DeQuantStub())

"""Prepare"""
m.train()
# QAT needs the QAT qconfig (fake-quantize modules); get_default_qconfig would only insert plain observers
m.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(m, inplace=True)
```
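The recipe stops right after `prepare_qat`. Below is a minimal sketch of the remaining steps (a short training loop, then the final `convert`) that also records the stub classes at each stage. Assumptions: random inputs, a dummy squared-output loss, and SGD with `lr=1e-3` stand in for real data and a real objective; the names `after_prepare`, `stub_has_fake_quant`, and `after_convert` are illustrative only; it uses the `torch.ao.quantization` namespace, which recent PyTorch versions re-export as `torch.quantization`; and it falls back to `qnnpack` when the build has no `fbgemm`.

```python
import torch
from torch import nn
import torch.ao.quantization as tq

# Assumption: prefer fbgemm (x86) and fall back to qnnpack (e.g. ARM) if this build lacks fbgemm.
backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = backend

# Same model, fusion and stubs as the recipe
m = nn.Sequential(nn.Conv2d(2, 64, 8), nn.ReLU(), nn.Conv2d(64, 128, 8), nn.ReLU())
tq.fuse_modules(m, [["0", "1"], ["2", "3"]], inplace=True)
m = nn.Sequential(tq.QuantStub(), *m, tq.DeQuantStub())

m.train()
m.qconfig = tq.get_default_qat_qconfig(backend)
tq.prepare_qat(m, inplace=True)

# Stub classes right after prepare_qat, and whether the QuantStub received a fake-quant module
after_prepare = (type(m[0]), type(m[-1]))
stub_has_fake_quant = hasattr(m[0], "activation_post_process")

# Hypothetical training loop: random inputs and a dummy loss stand in for real data
opt = torch.optim.SGD(m.parameters(), lr=1e-3)
for _ in range(3):
    x = torch.randn(2, 2, 16, 16)
    loss = m(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# Convert to the quantized model; the stubs are swapped at this point
m.eval()
tq.convert(m, inplace=True)
after_convert = (type(m[0]).__name__, type(m[-1]).__name__)

out = m(torch.randn(1, 2, 16, 16))  # float tensor in, float tensor out
print(after_prepare, stub_has_fake_quant, after_convert, tuple(out.shape))
```

In this sketch the stubs should still be `QuantStub`/`DeQuantStub` after `prepare_qat` (with a fake-quant module attached to the `QuantStub`) and only turn into `Quantize`/`DeQuantize` after the final `convert`.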
The `prepare_qat` method propagates the qconfig to all submodules, runs `convert`, which swaps modules for their QAT counterparts, and then calls `prepare`, which adds observers (fake-quantize modules) to the modules.
In the code above, `QuantStub` and `DeQuantStub` modules are added to the model. Why aren't they swapped out by the `convert` call inside `prepare_qat`?
Here is the `prepare_qat` implementation from pytorch/quantize.py at commit 748d9d24940cd17938df963456c90fa1a13f3932 (pytorch/pytorch on GitHub):
```python
# Excerpt from quantize.py: copy, propagate_qconfig_, convert, prepare and
# get_default_qat_module_mappings are imported or defined elsewhere in that module.
def prepare_qat(model, mapping=None, inplace=False, allow_list=None):
    r"""
    Prepares a copy of the model for quantization calibration or
    quantization-aware training and converts it to quantized version.
    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.
    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to quantized modules to be
                 replaced.
        inplace: carry out model transformations in-place, the original module
                 is mutated
        allow_list: a set that lists out allowable modules to be propagated with qconfig
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    if mapping is None:
        mapping = get_default_qat_module_mappings()  # float module class -> QAT module class
    if not inplace:
        model = copy.deepcopy(model)
    propagate_qconfig_(model, qconfig_dict=None, allow_list=allow_list)
    # swaps each module whose type is a key in `mapping`
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    # attaches observers / fake-quantize modules
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model
```
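One way to see which module types the `convert` call inside `prepare_qat` can swap is to check whether the stub classes are keys in the default QAT mapping (what `prepare_qat` falls back to when `mapping=None`) versus the default static mapping that a plain `convert(model)` uses. A minimal sketch, assuming a PyTorch version that ships `torch.ao.quantization.quantization_mappings`; `qat_map` and `static_map` are illustrative names:

```python
import torch
from torch import nn
from torch.ao.quantization import QuantStub, DeQuantStub
from torch.ao.quantization.quantization_mappings import (
    get_default_qat_module_mappings,
    get_default_static_quant_module_mappings,
)

qat_map = get_default_qat_module_mappings()              # passed to convert inside prepare_qat
static_map = get_default_static_quant_module_mappings()  # used by a plain convert(model)

# Report whether each class is a key (i.e. a swap candidate) in each mapping
for cls in (QuantStub, DeQuantStub, nn.Conv2d):
    print(cls.__name__, "| in QAT mapping:", cls in qat_map,
          "| in static mapping:", cls in static_map)
```

`convert` only replaces a module whose type is a key in the mapping it receives, so if the stubs are absent from the QAT mapping they pass through `prepare_qat` unchanged (apart from `prepare` attaching a fake-quant module to the `QuantStub`), and they are swapped only by the final `convert` call, which uses the static mapping.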
Thank you!