Torch 2.0 compile not compatible with FX Graph Mode Quantization?

Hi!
I’ve been trying the new torch.compile feature together with FX graph mode quantization. However, compiling the QAT-prepared model gives a shape mismatch error.

from torchvision.models import resnet18
import torch
from torch.ao.quantization.quantize_fx import prepare_qat_fx

net = resnet18().to(torch.float32).cuda()
net.train()
qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('x86')
example_input = (torch.randn((10, 3, 128, 128)),)
net = prepare_qat_fx(net, qconfig_mapping, example_input)

net = torch.compile(net)
example_input = torch.randn((10, 3, 128, 128)).cuda()
output = net(example_input)

Running this code gives me the following error:

  File "/home/test_demo.py", line 14, in <module>
    output = net(example_input)
  File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "<eval_with_key>.2", line 4, in forward
  File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2819, in forward
    return compiled_fn(full_args)
  File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1222, in g
    return f(*args)
  File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2386, in debug_compiled_function
    return compiled_function(*args)
  File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1970, in runtime_wrapper
    original_inpt.copy_(updated_inpt)
RuntimeError: output with shape [1] doesn't match the broadcast shape [64]

If I comment out the following lines, it runs fine.

qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('x86')
example_input = (torch.randn((10, 3, 128, 128)),)
net = prepare_qat_fx(net, qconfig_mapping, example_input)

I ran into a similar problem.

I was trying to quantize a compiled model. Adapting your example, the code looks like this:

from torchvision.models import resnet18
import torch
from torch.ao.quantization.quantize_fx import prepare_qat_fx

net = resnet18().to(torch.float32)
net.train()

example_input = torch.randn((10, 3, 128, 128))
net = torch.compile(net)
output = net(example_input)

qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('x86')
net = prepare_qat_fx(net, qconfig_mapping, example_input)

output = net(example_input)

This gives the following error message:

Traceback (most recent call last):
  File "fx_demo.py", line 13, in <module>
    net = prepare_qat_fx(net, qconfig_mapping, example_input)
  File "/nfs1/fan.mo/00-pyenv/torch2.0/lib/python3.8/site-packages/torch/ao/quantization/quantize_fx.py", line 487, in prepare_qat_fx
    return _prepare_fx(
  File "/nfs1/fan.mo/00-pyenv/torch2.0/lib/python3.8/site-packages/torch/ao/quantization/quantize_fx.py", line 133, in _prepare_fx
    graph_module = GraphModule(model, tracer.trace(model))
  File "/nfs1/fan.mo/00-pyenv/torch2.0/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/nfs1/fan.mo/00-pyenv/torch2.0/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "/nfs1/fan.mo/00-pyenv/torch2.0/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/nfs1/fan.mo/00-pyenv/torch2.0/lib/python3.8/site-packages/torch/fx/proxy.py", line 385, in __iter__
    return self.tracer.iter(self)
  File "/nfs1/fan.mo/00-pyenv/torch2.0/lib/python3.8/site-packages/torch/fx/proxy.py", line 285, in iter
    raise TraceError('Proxy object cannot be iterated. This can be '
torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors
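The traceback above shows FX tracing entering the compiled wrapper's forward (`self.dynamo_ctx(self._orig_mod.forward)`). One possible workaround is to hand the original eager module to prepare_qat_fx instead of the compiled wrapper. The following is a minimal sketch, not a confirmed fix. It assumes the wrapper returned by torch.compile keeps the original module in the private `_orig_mod` attribute seen in the traceback, which may change between releases. `TinyNet` and `unwrap_compiled` are hypothetical names used only for illustration, with `TinyNet` standing in for resnet18 so the sketch runs quickly on CPU.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx


class TinyNet(nn.Module):
    """Hypothetical small stand-in for resnet18."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


def unwrap_compiled(module):
    """Hypothetical helper: return the eager module behind a torch.compile wrapper.

    Assumption: the wrapper exposes the original module as `_orig_mod`
    (a private attribute). Modules without it are returned unchanged.
    """
    return getattr(module, "_orig_mod", module)


net = TinyNet().train()
compiled_net = torch.compile(net)  # only wraps the module; compilation is lazy
example_inputs = (torch.randn(2, 3, 16, 16),)

# FX-trace the eager module rather than the compiled wrapper.
eager_net = unwrap_compiled(compiled_net)
qconfig_mapping = get_default_qat_qconfig_mapping("x86")
prepared = prepare_qat_fx(eager_net, qconfig_mapping, example_inputs)

output = prepared(*example_inputs)
```

The prepared model here runs in eager mode. Whether it can then be compiled again is exactly the open question of this thread, so the sketch stops before that step.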

Quantization support in torch.compile is still in its early days, but if you’re curious you can see how it works here: https://github.com/facebookexperimental/protoquant

Hi @Archer_Z

I’m not sure whether torch.compile will work with the prepared QAT model. I do think, however, that you should be able to follow the QAT workflow and call torch.compile on the converted model it returns. There’s a QAT tutorial here if it’s helpful.
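As a rough illustration of that ordering (prepare, train, convert, and only then compile), here is a minimal sketch. It is not a confirmed recipe: nothing in this thread establishes that torch.compile handles the converted model, so the compile step is wrapped in a fallback to eager execution. `TinyNet` and `qat_then_compile` are hypothetical names, the single forward/backward pass stands in for a real training loop, and the sketch assumes a CPU build with the x86 quantized backend available.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx


class TinyNet(nn.Module):
    """Hypothetical small stand-in for resnet18."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


def qat_then_compile():
    example_inputs = (torch.randn(2, 3, 16, 16),)

    # 1. Prepare for QAT in eager mode, without torch.compile.
    model = TinyNet().train()
    qconfig_mapping = get_default_qat_qconfig_mapping("x86")
    prepared = prepare_qat_fx(model, qconfig_mapping, example_inputs)

    # 2. Placeholder for the real training loop: one forward/backward pass.
    prepared(*example_inputs).sum().backward()

    # 3. Convert the trained, prepared model to a quantized model.
    prepared.eval()
    quantized = convert_fx(prepared)

    # 4. Try torch.compile on the converted model. Compilation is lazy, so
    #    the first call sits inside the try block as well. If anything fails,
    #    fall back to running the quantized model in eager mode.
    try:
        compiled = torch.compile(quantized)
        out = compiled(*example_inputs)
        used_compile = True
    except Exception:
        out = quantized(*example_inputs)
        used_compile = False
    return out, used_compile


out, used_compile = qat_then_compile()
```

The `used_compile` flag reports which path actually ran, which makes it easy to check on a given PyTorch build whether the compiled path is usable at all.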

cc @andrewor do you know if this is supported / in our roadmap?

Hi @Archer_Z, there are no plans right now to support torch.compile with FX graph mode quantization, so it may not work well out of the box. We do have plans to support a new PT2.0 QAT flow that is compatible with torch.compile, but that is still in the works and may not be ready until later this year.
