BatchNorm and ConvTranspose Fusion for QAT with FX Graph Mode

Description

I’m trying to prepare a model for quantization-aware training (QAT) using FX Graph Mode quantization. (Ultimately I want to run it with int8 precision using TensorRT, but that’s not the issue for now.) Tracing works fine; the problem occurs during the fusion stage. My model uses BatchNorm and ConvTranspose modules, for which fusion is not yet supported in QAT. I tried just skipping the fusion of these specific modules, but without success (my rough attempt is shown after the traceback below). I’m not familiar enough with this part of PyTorch. My questions are:

  • Can I skip the ConvTranspose+BatchNorm fusion and continue with QAT? If so, how?
  • Do I need to fuse at all before fine-tuning/QAT, or can I get away with fusing after QAT?

I’m running:

from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx

qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
custom_config = PrepareCustomConfig()
model_prepared = prepare_qat_fx(model, qconfig_mapping, example_input, custom_config)  # example_input: tuple of example inputs

The error that I get:

model_prepared = prepare_qat_fx(model, qconfig_mapping, example_input, custom_config)
  File "\venv\lib\site-packages\torch\ao\quantization\quantize_fx.py", line 489, in prepare_qat_fx
    return _prepare_fx(
  File "\venv\lib\site-packages\torch\ao\quantization\quantize_fx.py", line 139, in _prepare_fx
    graph_module = _fuse_fx(
  File "\venv\lib\site-packages\torch\ao\quantization\quantize_fx.py", line 91, in _fuse_fx
    return fuse(
  File "\venv\lib\site-packages\torch\ao\quantization\fx\fuse.py", line 112, in fuse
    env[node.name] = obj.fuse(
  File "\venv\lib\site-packages\torch\ao\quantization\fx\fuse_handler.py", line 102, in fuse
    fused_module = fuser_method(is_qat, *matched_modules)
  File "\venv\lib\site-packages\torch\ao\quantization\fuser_method_mappings.py", line 195, in reversed
    return f(is_qat, y, x)
  File "\venv\lib\site-packages\torch\ao\quantization\fuser_method_mappings.py", line 149, in fuse_convtranspose_bn
    raise Exception("Fusing ConvTranspose+BatchNorm not yet supported in QAT.")
Exception: Fusing ConvTranspose+BatchNorm not yet supported in QAT.
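
For completeness, my attempt at skipping the fusion was roughly the following: copy the native BackendConfig, drop the ConvTranspose+BatchNorm fusion patterns, and pass the result to prepare_qat_fx. This is only a sketch (the _contains helper and the config name are mine, and some pattern configs may use an internal format this misses), and as said above it didn’t get me to a working setup:

import torch.nn as nn
from torch.ao.quantization.backend_config import BackendConfig, get_native_backend_config

CONV_TRANSPOSE = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
BATCH_NORM = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def _contains(pattern, types):
    # Fusion patterns are types or (possibly nested) tuples of types/functions.
    if isinstance(pattern, tuple):
        return any(_contains(p, types) for p in pattern)
    return pattern in types

# Copy every pattern config except the ConvTranspose+BatchNorm fusions.
backend_config = BackendConfig("native_without_convtranspose_bn")
for cfg in get_native_backend_config().configs:
    pattern = getattr(cfg, "pattern", None)
    if _contains(pattern, CONV_TRANSPOSE) and _contains(pattern, BATCH_NORM):
        continue  # drop the fusion pattern that raises in QAT
    backend_config.set_backend_pattern_config(cfg)

model_prepared = prepare_qat_fx(
    model, qconfig_mapping, example_input, custom_config,
    backend_config=backend_config,
)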

Thanks!

Environment

Operating System: Windows
Python Version: 3.9.12
PyTorch Version: 2.1.1+cu121

Steps To Reproduce

I’m happy to add an example model if it is needed/helpful.

FX graph mode quantization is currently in maintenance mode. I think you can start with our new flow, PyTorch 2 Export quantization (see Quantization — PyTorch main documentation). Note that this flow does not currently include ConvTranspose + BN fusion in QAT.
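
Roughly, the new QAT flow looks like this (a minimal sketch assuming XNNPACKQuantizer, that model is your nn.Module, and a made-up input shape):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

example_inputs = (torch.randn(1, 3, 224, 224),)  # made-up shape

# Export the model to an FX graph before autograd; this relies on
# torch._dynamo, so it needs a platform where torch.compile is supported.
exported_model = capture_pre_autograd_graph(model, example_inputs)

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
prepared = prepare_qat_pt2e(exported_model, quantizer)

# ... run your QAT fine-tuning loop on `prepared` ...

quantized = convert_pt2e(prepared)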


It seems that I need to switch to a Linux system for this (see the traceback below), but that should be doable. Thanks!
Do you suspect any difficulties in exporting to ONNX and using TensorRT with this approach?

  File "\venv\lib\site-packages\torch\_export\__init__.py", line 145, in capture_pre_autograd_graph
    m = torch._dynamo.export(
  File "\venv\lib\site-packages\torch\_dynamo\eval_frame.py", line 1031, in inner
    check_if_dynamo_supported()
  File "\venv\lib\site-packages\torch\_dynamo\eval_frame.py", line 535, in check_if_dynamo_supported
    raise RuntimeError("Windows not yet supported for torch.compile")
RuntimeError: Windows not yet supported for torch.compile

Yeah, Windows is not supported…

Do you suspect any difficulties in exporting to ONNX and using TensorRT with this approach?

I don’t see that there will be any unfixable issues, but I haven’t tried this before.