XNNPACKQuantizer.set_module_name() not working as expected

Hello. I’m experimenting with PT2E PTQ on my object detection model. Since quantizing the whole model completely zeroes out its precision (AP drops to 0), I decided to quantize only its backbone, as shown below:

import argparse
from copy import deepcopy
from os import cpu_count
from pathlib import Path
from shutil import get_terminal_size
import sys
import warnings
import torch
from torch.export import export, export_for_training
from torch.ao.quantization import move_exported_model_to_eval
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from tqdm import tqdm
# Make the repo root importable regardless of the working directory.
# BUG FIX: the original appended the script *file* itself to sys.path;
# sys.path entries must be directories, so use .parent.
_ROOT = str(Path(__file__).resolve().parent)
if _ROOT not in sys.path:
    sys.path.append(_ROOT)
from module.fastestdetv2 import FastestDetV2
from utils.config import Config
from utils.datasets import collate_fn, Dataset
from utils.evaluator import COCODetectionEvaluator
from utils.loss import DetectorLoss

if __name__ == "__main__":
    # --- CLI ---
    parser = argparse.ArgumentParser()
    # parser.add_argument("--weights", type=str, required=True, help=".pt weights, reparameterized")
    parser.add_argument("--weights", type=str, default="checkpoints/fastestdetv2_qamobileone_coco_ap50,0.000000_ep0.pth", help=".pt weights, reparameterized")
    parser.add_argument("--configs", type=str, default=str(Path(__file__).parent/"configs/coco.yaml"), help=".yaml configs")
    parser.add_argument("--target", type=str, default="arm", help="target platform, arm or x86")
    opt = parser.parse_args()
    cfg = Config(opt.configs)
    cfg_name = Path(opt.configs).stem
    savedir = Path(__file__).resolve().parent/"checkpoints"
    savedir.mkdir(exist_ok=True)
    ncols = get_terminal_size().columns
    # convert_pt2e can emit noisy "erase_node on an already erased node" warnings; harmless
    warnings.filterwarnings("ignore", message=".*erase_node(.*) on an already erased node.*")
    # data loaders
    # cpu_count() may return None (restricted containers); fall back so max() never sees None
    num_workers = max(4, (cpu_count() or 4) // 4)
    calib_dataset = Dataset(cfg.train_txt, cfg.input_size, aug=False)
    val_dataset = Dataset(cfg.val_txt, cfg.input_size, aug=False)
    calib_loader = torch.utils.data.DataLoader(calib_dataset, cfg.batch_size,
        shuffle=True, collate_fn=collate_fn, drop_last=True,
        num_workers=num_workers, persistent_workers=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, cfg.batch_size,
        shuffle=False, collate_fn=collate_fn, drop_last=False,
        num_workers=num_workers, persistent_workers=True)
    # model
    model = FastestDetV2(cfg.num_classes,
        load_weights=True, inference_mode=True).eval()
    # map_location keeps GPU-saved checkpoints loadable on CPU-only machines;
    # PT2E PTQ below runs on CPU anyway
    model.load_state_dict(torch.load(opt.weights, map_location="cpu"))
    print(f"Loaded detector weights {opt.weights}")
    proj_name = f"{type(model).__name__.lower()}_{type(model.backbone).__name__.lower()}_{cfg_name}"
    # export (capture the graph before annotating it with a quantizer)
    dummy_inputs = (torch.randn(cfg.batch_size, 3, cfg.input_size[1], cfg.input_size[0]),)
    model = export_for_training(model, dummy_inputs).module()
    # quantizer
    if opt.target == "x86":
        import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
        qconfig = xiq.get_default_x86_inductor_quantization_config()
        quantizer = xiq.X86InductorQuantizer()
        for n, m in model.named_modules():
            # ("backbone") is just the string "backbone" — keep it unparenthesized
            # to avoid the pseudo-tuple reading
            if n.startswith("backbone"):
                quantizer.set_module_name_qconfig(n, qconfig)
    else:
        import torch.ao.quantization.quantizer.xnnpack_quantizer as xpq
        qconfig = xpq.get_symmetric_quantization_config()
        quantizer = xpq.XNNPACKQuantizer()
        # # 1) 3.7M -> ~940K
        # quantizer.set_global(qconfig)
        # 2) 3.7M -> 3.6M
        for n, m in model.named_modules():
            if n.startswith("backbone"):
                quantizer.set_module_name(n, qconfig)
    model = prepare_pt2e(model, quantizer)
    # calibration: run forward passes so the inserted observers collect statistics
    criterion = DetectorLoss()
    move_exported_model_to_eval(model)
    print("Start calibrating")
    with torch.no_grad():
        pbar = tqdm(calib_loader, ncols=ncols)
        avg_iou, avg_obj, avg_cls, avg_total = 0.0, 0.0, 0.0, 0.0
        for ib, (imgs, labels) in enumerate(pbar):
            if ib > 0: break # skip for testing
            imgs = imgs.float() / 255.0
            outputs = model(imgs)
            iou, obj, cls, total = criterion(outputs, labels)
            avg_iou += iou.item()
            avg_obj += obj.item()
            avg_cls += cls.item()
            avg_total += total.item()
            pbar.set_description(f"iou{avg_iou/(ib+1):.2f} obj{avg_obj/(ib+1):.2f} "
                f"cls{avg_cls/(ib+1):.2f} loss{avg_total/(ib+1):.2f}")
    # convert: fold observers into quantize/dequantize ops and freeze int8 weights
    model_quant = convert_pt2e(deepcopy(model), fold_quantize=True)
    model_quant = export(model_quant, dummy_inputs).module()
    # stats = COCODetectionEvaluator(cfg.names).eval(
    #     val_loader, model_quant, ncols=ncols, colour="green")
    stats = {"coco/AP50": 0.0} # skip for testing
    torch.save(model_quant.state_dict(), str(savedir/
        f"{proj_name}_ap50,{stats['coco/AP50']:.6f}_ptq,{opt.target}.pth"))

Quantization with X86InductorQuantizer worked really well, with the final model size reduced from 3.7M to 1.4M:

        for n, m in model.named_modules():
            if n.startswith(("backbone")):
                quantizer.set_module_name_qconfig(n, qconfig)

However, when using XNNPACKQuantizer, .set_module_name() only reduces the model size by less than 0.1M, and the backbone’s state dict still mainly holds float32-typed values, while .set_global() reduces the model to about 940K (although I don’t want to quantize the detection head):

        for n, m in model.named_modules():
            if n.startswith(("backbone")):
                quantizer.set_module_name(n, qconfig)

I use PyTorch 2.6 since it’s the newest version supported by my lab workstation.

Am I doing this wrong, or has this been fixed in newer versions? Thanks in advance.

Update: I also tried quantizer.set_module_type(type(model.backbone), qconfig) and the model size is still 3.6M. Quantization is still wrong with XNNPACKQuantizer.

The weights structure of the original model:

backbone.stage1.0.bn.bias torch.float32
backbone.stage1.0.bn.num_batches_tracked torch.int64
backbone.stage1.0.bn.running_mean torch.float32
backbone.stage1.0.bn.running_var torch.float32
backbone.stage1.0.bn.weight torch.float32
backbone.stage1.0.reparam_conv.bias torch.float32
backbone.stage1.0.reparam_conv.weight torch.float32
backbone.stage2.0.bn.bias torch.float32
backbone.stage2.0.bn.num_batches_tracked torch.int64
backbone.stage2.0.bn.running_mean torch.float32
backbone.stage2.0.bn.running_var torch.float32
backbone.stage2.0.bn.weight torch.float32
backbone.stage2.0.reparam_conv.bias torch.float32
backbone.stage2.0.reparam_conv.weight torch.float32
backbone.stage2.2.bn.bias torch.float32
backbone.stage2.2.bn.num_batches_tracked torch.int64
backbone.stage2.2.bn.running_mean torch.float32
backbone.stage2.2.bn.running_var torch.float32
backbone.stage2.2.bn.weight torch.float32
backbone.stage2.2.reparam_conv.bias torch.float32
backbone.stage2.2.reparam_conv.weight torch.float32
backbone.stage3.0.bn.bias torch.float32
backbone.stage3.0.bn.num_batches_tracked torch.int64
backbone.stage3.0.bn.running_mean torch.float32
backbone.stage3.0.bn.running_var torch.float32
backbone.stage3.0.bn.weight torch.float32
backbone.stage3.0.reparam_conv.bias torch.float32
backbone.stage3.0.reparam_conv.weight torch.float32
backbone.stage3.2.bn.bias torch.float32
backbone.stage3.2.bn.num_batches_tracked torch.int64
backbone.stage3.2.bn.running_mean torch.float32
backbone.stage3.2.bn.running_var torch.float32
backbone.stage3.2.bn.weight torch.float32
backbone.stage3.2.reparam_conv.bias torch.float32
backbone.stage3.2.reparam_conv.weight torch.float32
backbone.stage3.4.bn.bias torch.float32
backbone.stage3.4.bn.num_batches_tracked torch.int64
backbone.stage3.4.bn.running_mean torch.float32
backbone.stage3.4.bn.running_var torch.float32
backbone.stage3.4.bn.weight torch.float32
backbone.stage3.4.reparam_conv.bias torch.float32
backbone.stage3.4.reparam_conv.weight torch.float32
backbone.stage4.0.bn.bias torch.float32
backbone.stage4.0.bn.num_batches_tracked torch.int64
backbone.stage4.0.bn.running_mean torch.float32
backbone.stage4.0.bn.running_var torch.float32
backbone.stage4.0.bn.weight torch.float32
backbone.stage4.0.reparam_conv.bias torch.float32
backbone.stage4.0.reparam_conv.weight torch.float32
backbone.stage4.2.bn.bias torch.float32
backbone.stage4.2.bn.num_batches_tracked torch.int64
backbone.stage4.2.bn.running_mean torch.float32
backbone.stage4.2.bn.running_var torch.float32
backbone.stage4.2.bn.weight torch.float32
backbone.stage4.2.reparam_conv.bias torch.float32
backbone.stage4.2.reparam_conv.weight torch.float32
backbone.stem.0.bn.bias torch.float32
backbone.stem.0.bn.num_batches_tracked torch.int64
backbone.stem.0.bn.running_mean torch.float32
backbone.stem.0.bn.running_var torch.float32
backbone.stem.0.bn.weight torch.float32
backbone.stem.0.reparam_conv.bias torch.float32
backbone.stem.0.reparam_conv.weight torch.float32
backbone.stem.2.bn.bias torch.float32
backbone.stem.2.bn.num_batches_tracked torch.int64
backbone.stem.2.bn.running_mean torch.float32
backbone.stem.2.bn.running_var torch.float32
backbone.stem.2.bn.weight torch.float32
backbone.stem.2.reparam_conv.bias torch.float32
backbone.stem.2.reparam_conv.weight torch.float32
det.cls_head.conv1x1.bn.bias torch.float32
det.cls_head.conv1x1.bn.num_batches_tracked torch.int64
det.cls_head.conv1x1.bn.running_mean torch.float32
det.cls_head.conv1x1.bn.running_var torch.float32
det.cls_head.conv1x1.bn.weight torch.float32
det.cls_head.conv1x1.reparam_conv.bias torch.float32
det.cls_head.conv1x1.reparam_conv.weight torch.float32
det.cls_head.conv5x5r.0.bn.bias torch.float32
det.cls_head.conv5x5r.0.bn.num_batches_tracked torch.int64
det.cls_head.conv5x5r.0.bn.running_mean torch.float32
det.cls_head.conv5x5r.0.bn.running_var torch.float32
det.cls_head.conv5x5r.0.bn.weight torch.float32
det.cls_head.conv5x5r.0.reparam_conv.bias torch.float32
det.cls_head.conv5x5r.0.reparam_conv.weight torch.float32
det.conv1x1r.0.bn.bias torch.float32
det.conv1x1r.0.bn.num_batches_tracked torch.int64
det.conv1x1r.0.bn.running_mean torch.float32
det.conv1x1r.0.bn.running_var torch.float32
det.conv1x1r.0.bn.weight torch.float32
det.conv1x1r.0.reparam_conv.bias torch.float32
det.conv1x1r.0.reparam_conv.weight torch.float32
det.conv3x3.bn.bias torch.float32
det.conv3x3.bn.num_batches_tracked torch.int64
det.conv3x3.bn.running_mean torch.float32
det.conv3x3.bn.running_var torch.float32
det.conv3x3.bn.weight torch.float32
det.conv3x3.reparam_conv.bias torch.float32
det.conv3x3.reparam_conv.weight torch.float32
det.obj_head.conv1x1.bn.bias torch.float32
det.obj_head.conv1x1.bn.num_batches_tracked torch.int64
det.obj_head.conv1x1.bn.running_mean torch.float32
det.obj_head.conv1x1.bn.running_var torch.float32
det.obj_head.conv1x1.bn.weight torch.float32
det.obj_head.conv1x1.reparam_conv.bias torch.float32
det.obj_head.conv1x1.reparam_conv.weight torch.float32
det.obj_head.conv5x5r.0.bn.bias torch.float32
det.obj_head.conv5x5r.0.bn.num_batches_tracked torch.int64
det.obj_head.conv5x5r.0.bn.running_mean torch.float32
det.obj_head.conv5x5r.0.bn.running_var torch.float32
det.obj_head.conv5x5r.0.bn.weight torch.float32
det.obj_head.conv5x5r.0.reparam_conv.bias torch.float32
det.obj_head.conv5x5r.0.reparam_conv.weight torch.float32
det.reg_head.conv1x1.bn.bias torch.float32
det.reg_head.conv1x1.bn.num_batches_tracked torch.int64
det.reg_head.conv1x1.bn.running_mean torch.float32
det.reg_head.conv1x1.bn.running_var torch.float32
det.reg_head.conv1x1.bn.weight torch.float32
det.reg_head.conv1x1.reparam_conv.bias torch.float32
det.reg_head.conv1x1.reparam_conv.weight torch.float32
det.reg_head.conv5x5r.0.bn.bias torch.float32
det.reg_head.conv5x5r.0.bn.num_batches_tracked torch.int64
det.reg_head.conv5x5r.0.bn.running_mean torch.float32
det.reg_head.conv5x5r.0.bn.running_var torch.float32
det.reg_head.conv5x5r.0.bn.weight torch.float32
det.reg_head.conv5x5r.0.reparam_conv.bias torch.float32
det.reg_head.conv5x5r.0.reparam_conv.weight torch.float32
spp.conv1x1.0.bn.bias torch.float32
spp.conv1x1.0.bn.num_batches_tracked torch.int64
spp.conv1x1.0.bn.running_mean torch.float32
spp.conv1x1.0.bn.running_var torch.float32
spp.conv1x1.0.bn.weight torch.float32
spp.conv1x1.0.reparam_conv.bias torch.float32
spp.conv1x1.0.reparam_conv.weight torch.float32
spp.out.bn.bias torch.float32
spp.out.bn.num_batches_tracked torch.int64
spp.out.bn.running_mean torch.float32
spp.out.bn.running_var torch.float32
spp.out.bn.weight torch.float32
spp.out.reparam_conv.bias torch.float32
spp.out.reparam_conv.weight torch.float32
spp.s1.1.bn.bias torch.float32
spp.s1.1.bn.num_batches_tracked torch.int64
spp.s1.1.bn.running_mean torch.float32
spp.s1.1.bn.running_var torch.float32
spp.s1.1.bn.weight torch.float32
spp.s1.1.reparam_conv.bias torch.float32
spp.s1.1.reparam_conv.weight torch.float32
spp.s1.3.bn.bias torch.float32
spp.s1.3.bn.num_batches_tracked torch.int64
spp.s1.3.bn.running_mean torch.float32
spp.s1.3.bn.running_var torch.float32
spp.s1.3.bn.weight torch.float32
spp.s1.3.reparam_conv.bias torch.float32
spp.s1.3.reparam_conv.weight torch.float32
spp.s2.1.bn.bias torch.float32
spp.s2.1.bn.num_batches_tracked torch.int64
spp.s2.1.bn.running_mean torch.float32
spp.s2.1.bn.running_var torch.float32
spp.s2.1.bn.weight torch.float32
spp.s2.1.reparam_conv.bias torch.float32
spp.s2.1.reparam_conv.weight torch.float32
spp.s2.3.bn.bias torch.float32
spp.s2.3.bn.num_batches_tracked torch.int64
spp.s2.3.bn.running_mean torch.float32
spp.s2.3.bn.running_var torch.float32
spp.s2.3.bn.weight torch.float32
spp.s2.3.reparam_conv.bias torch.float32
spp.s2.3.reparam_conv.weight torch.float32
spp.s3.1.bn.bias torch.float32
spp.s3.1.bn.num_batches_tracked torch.int64
spp.s3.1.bn.running_mean torch.float32
spp.s3.1.bn.running_var torch.float32
spp.s3.1.bn.weight torch.float32
spp.s3.1.reparam_conv.bias torch.float32
spp.s3.1.reparam_conv.weight torch.float32
spp.s3.3.bn.bias torch.float32
spp.s3.3.bn.num_batches_tracked torch.int64
spp.s3.3.bn.running_mean torch.float32
spp.s3.3.bn.running_var torch.float32
spp.s3.3.bn.weight torch.float32
spp.s3.3.reparam_conv.bias torch.float32
spp.s3.3.reparam_conv.weight torch.float32

XNNPACKQuantizer-quantized model, with .set_module_name("backbone", qconfig), .set_module_type(type(model.backbone), qconfig), or the per-module loop quantizer.set_module_name(n, qconfig) shown above (all three give the same result):

backbone.stage1.0.reparam_conv.bias torch.float32
backbone.stage1.0.reparam_conv.weight torch.float32
backbone.stage2.0.reparam_conv.bias torch.float32
backbone.stage2.0.reparam_conv.weight torch.float32
backbone.stage2.2.reparam_conv.bias torch.float32
backbone.stage2.2.reparam_conv.weight torch.float32
backbone.stage3.0.reparam_conv.bias torch.float32
backbone.stage3.0.reparam_conv.weight torch.float32
backbone.stage3.2.reparam_conv.bias torch.float32
backbone.stage3.2.reparam_conv.weight torch.float32
backbone.stage3.4.reparam_conv.bias torch.float32
backbone.stage3.4.reparam_conv.weight torch.float32
backbone.stage4.0.reparam_conv.bias torch.float32
backbone.stage4.0.reparam_conv.weight torch.float32
backbone.stage4.2.reparam_conv.bias torch.float32
backbone.stage4.2.reparam_conv.weight torch.float32
backbone.stem.0.reparam_conv.bias torch.float32
backbone.stem.0.reparam_conv.weight torch.float32
backbone.stem.2.reparam_conv.bias torch.float32
backbone.stem.2.reparam_conv.weight torch.float32
det.cls_head.conv1x1.reparam_conv.bias torch.float32
det.cls_head.conv1x1.reparam_conv.weight torch.float32
det.cls_head.conv5x5r.0.reparam_conv.bias torch.float32
det.cls_head.conv5x5r.0.reparam_conv.weight torch.float32
det.conv1x1r.0.reparam_conv.bias torch.float32
det.conv1x1r.0.reparam_conv.weight torch.float32
det.conv3x3.reparam_conv.bias torch.float32
det.conv3x3.reparam_conv.weight torch.float32
det.obj_head.conv1x1.reparam_conv.bias torch.float32
det.obj_head.conv1x1.reparam_conv.weight torch.float32
det.obj_head.conv5x5r.0.reparam_conv.bias torch.float32
det.obj_head.conv5x5r.0.reparam_conv.weight torch.float32
det.reg_head.conv1x1.reparam_conv.bias torch.float32
det.reg_head.conv1x1.reparam_conv.weight torch.float32
det.reg_head.conv5x5r.0.reparam_conv.bias torch.float32
det.reg_head.conv5x5r.0.reparam_conv.weight torch.float32
spp.conv1x1.0.reparam_conv.bias torch.float32
spp.conv1x1.0.reparam_conv.weight torch.float32
spp.out.reparam_conv.bias torch.float32
spp.out.reparam_conv.weight torch.float32
spp.s1.1.reparam_conv.bias torch.float32
spp.s1.1.reparam_conv.weight torch.float32
spp.s1.3.reparam_conv.bias torch.float32
spp.s1.3.reparam_conv.weight torch.float32
spp.s2.1.reparam_conv.bias torch.float32
spp.s2.1.reparam_conv.weight torch.float32
spp.s2.3.reparam_conv.bias torch.float32
spp.s2.3.reparam_conv.weight torch.float32
spp.s3.1.reparam_conv.bias torch.float32
spp.s3.1.reparam_conv.weight torch.float32
spp.s3.3.reparam_conv.bias torch.float32
spp.s3.3.reparam_conv.weight torch.float32

XNNPACKQuantizer-quantized model with .set_global(qconfig):

_frozen_param0 torch.int8
_frozen_param1 torch.int8
_frozen_param2 torch.int8
_frozen_param3 torch.int8
_frozen_param4 torch.int8
_frozen_param5 torch.int8
_frozen_param6 torch.int8
_frozen_param7 torch.int8
_frozen_param8 torch.int8
_frozen_param9 torch.int8
_frozen_param10 torch.int8
_frozen_param11 torch.int8
_frozen_param12 torch.int8
_frozen_param13 torch.int8
_frozen_param14 torch.int8
_frozen_param15 torch.int8
_frozen_param16 torch.int8
_frozen_param17 torch.int8
_frozen_param18 torch.int8
_frozen_param19 torch.int8
_frozen_param20 torch.int8
_frozen_param21 torch.int8
_frozen_param22 torch.int8
_frozen_param23 torch.int8
_frozen_param24 torch.int8
_frozen_param25 torch.int8
backbone.stage1.0.reparam_conv.bias torch.float32
backbone.stage2.0.reparam_conv.bias torch.float32
backbone.stage2.2.reparam_conv.bias torch.float32
backbone.stage3.0.reparam_conv.bias torch.float32
backbone.stage3.2.reparam_conv.bias torch.float32
backbone.stage3.4.reparam_conv.bias torch.float32
backbone.stage4.0.reparam_conv.bias torch.float32
backbone.stage4.2.reparam_conv.bias torch.float32
backbone.stem.0.reparam_conv.bias torch.float32
backbone.stem.2.reparam_conv.bias torch.float32
det.cls_head.conv1x1.reparam_conv.bias torch.float32
det.cls_head.conv5x5r.0.reparam_conv.bias torch.float32
det.conv1x1r.0.reparam_conv.bias torch.float32
det.conv3x3.reparam_conv.bias torch.float32
det.obj_head.conv1x1.reparam_conv.bias torch.float32
det.obj_head.conv5x5r.0.reparam_conv.bias torch.float32
det.reg_head.conv1x1.reparam_conv.bias torch.float32
det.reg_head.conv5x5r.0.reparam_conv.bias torch.float32
spp.conv1x1.0.reparam_conv.bias torch.float32
spp.out.reparam_conv.bias torch.float32
spp.s1.1.reparam_conv.bias torch.float32
spp.s1.3.reparam_conv.bias torch.float32
spp.s2.1.reparam_conv.bias torch.float32
spp.s2.3.reparam_conv.bias torch.float32
spp.s3.1.reparam_conv.bias torch.float32
spp.s3.3.reparam_conv.bias torch.float32