How to quantize a pre-trained model to float16

I can successfully convert resnet18 to int8 with PTSQ (post-training static quantization) in eager mode. However, when I try to quantize to float16 by changing the qconfig to torch.quantization.float16_static_qconfig, I hit this problem:
Traceback (most recent call last):
File "/home/itl/Documents/xrh/backdoor/trojanzoo/examples/", line 42, in
torch.quantization.convert(defense.attack.model._model, inplace=True)
File "/home/itl/anaconda3/envs/trojanzoo/lib/python3.10/site-packages/torch/ao/quantization/", line 505, in convert
File "/home/itl/anaconda3/envs/trojanzoo/lib/python3.10/site-packages/torch/ao/quantization/", line 541, in _convert
_convert(mod, mapping, True, # inplace
File "/home/itl/anaconda3/envs/trojanzoo/lib/python3.10/site-packages/torch/ao/quantization/", line 543, in _convert
reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
File "/home/itl/anaconda3/envs/trojanzoo/lib/python3.10/site-packages/torch/ao/quantization/", line 568, in swap_module
new_mod = mapping[type(mod)].from_float(mod)
File "/home/itl/anaconda3/envs/trojanzoo/lib/python3.10/site-packages/torch/nn/quantized/modules/", line 467, in from_float
return _ConvNd.from_float(cls, mod)
File "/home/itl/anaconda3/envs/trojanzoo/lib/python3.10/site-packages/torch/nn/quantized/modules/", line 238, in from_float
return cls.get_qconv(mod, activation_post_process, weight_post_process)
File "/home/itl/anaconda3/envs/trojanzoo/lib/python3.10/site-packages/torch/nn/quantized/modules/", line 198, in get_qconv
assert weight_post_process.dtype == torch.qint8,
AssertionError: Weight observer must have a dtype of qint8

model = ....
model = model.half()

The parameters are converted to float16, but some layers then fail to run in half precision.
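For context on the .half() route: it only casts parameters and buffers to float16, so numerically fragile ops (such as BatchNorm running statistics) can fail or degrade when run purely in half precision. A minimal sketch (the toy model is illustrative, not the poster's resnet18); on GPU, autocast mixed precision is the usual way to keep such ops in float32:

```python
import torch

# Illustrative stand-in for a pre-trained model
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
)

# .half() casts every floating-point parameter and buffer to float16
model = model.half().eval()
assert all(p.dtype == torch.float16 for p in model.parameters())

# On GPU, autocast runs conv/matmul in float16 while automatically
# keeping numerically sensitive ops in float32:
if torch.cuda.is_available():
    gpu_model = torch.nn.Conv2d(3, 8, kernel_size=3).cuda()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = gpu_model(torch.randn(1, 3, 16, 16, device="cuda"))
```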


fp16 static quantization is not really supported in the native quantized backends (fbgemm/qnnpack); we previously added it for reference quantized model support in FX graph mode quantization. If that is what you want, you can use this qconfig and then set is_reference to True:

m = prepare_fx(m, {"": float16_static_qconfig}, example_inputs)
m = convert_fx(m, is_reference=True)
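Separately, if the end goal is just float16 weights through a path the native backend does support, dynamic float16 quantization of Linear layers works in eager mode. A minimal sketch (the toy model is illustrative): weights are stored in float16 and cast back at runtime, while activations stay float32:

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU()).eval()

# Dynamic float16 quantization: Linear weights are stored in float16
# and dequantized on the fly during the forward pass.
qmodel = torch.ao.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.float16
)

x = torch.randn(2, 16)
out = qmodel(x)  # activations remain float32
```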

When I use FX graph mode quantization in my framework, it raises a TraceError. How can I solve this error?

torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

Please check out: (prototype) FX Graph Mode Quantization User Guide — PyTorch Tutorials 1.11.0+cu102 documentation