Hi!
I’ve been trying the new `torch.compile` feature together with FX graph mode quantization. However, it seems that compiling the QAT-prepared model gives a shape mismatch error.
from torchvision.models import resnet18
import torch
from torch.ao.quantization.quantize_fx import prepare_qat_fx
# Build a float32 ResNet-18 on the GPU and switch it to training mode.
net = resnet18().to(torch.float32).cuda()
net.train()
# Default QAT qconfig mapping for the 'x86' backend.
qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('x86')
# NOTE(review): this example input is created on the CPU, while the model was moved to CUDA above.
example_input = (torch.randn((10, 3, 128, 128)),)
# Prepare the model for QAT with FX graph mode quantization.
net = prepare_qat_fx(net, qconfig_mapping, example_input)
# Wrap the prepared model with torch.compile; no forward pass has been run on it yet.
net = torch.compile(net)
# First forward pass, on a CUDA input; this is the call that raises the RuntimeError shown below.
example_input = torch.randn((10, 3, 128, 128)).cuda()
output = net(example_input)
Running this code gives me the following error:
File "/home/test_demo.py", line 14, in <module>
output = net(example_input)
File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "<eval_with_key>.2", line 4, in forward
File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2819, in forward
return compiled_fn(full_args)
File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1222, in g
return f(*args)
File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2386, in debug_compiled_function
return compiled_function(*args)
File "/home/.conda/envs/wzpy39/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1970, in runtime_wrapper
original_inpt.copy_(updated_inpt)
RuntimeError: output with shape [1] doesn't match the broadcast shape [64]
If I comment out the following lines, then it runs fine:
qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('x86')
example_input = (torch.randn((10, 3, 128, 128)),)
net = prepare_qat_fx(net, qconfig_mapping, example_input)