QAT int4: first-layer precision for an int4 model

The current QAT workflow uses the same precision for the fake_quant in EVERY LAYER:
fp32 → fake_quant → fp32
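The fp32 → fake_quant → fp32 step can be illustrated with `torch.fake_quantize_per_tensor_affine`: values are rounded onto an integer grid, clamped to `[quant_min, quant_max]`, and mapped back to float32. The input values and the `scale`/`zero_point` below are arbitrary illustrative choices; the two calls differ only in their integer range (8-bit vs 4-bit).

```python
import torch

# Float32 activations, as they flow between layers during QAT.
x = torch.tensor([0.0, 0.26, 0.5, 1.37, 2.0])

# Illustrative qparams; in real QAT an observer computes these from data.
scale, zero_point = 0.1, 0

# fp32 -> fake_quant -> fp32: round(x / scale) + zero_point, clamp to
# [quant_min, quant_max], then map back to float. The result stays float32.
y8 = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 255)  # 8-bit grid
y4 = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 15)   # 4-bit grid

print(y8)  # 2.0 survives on the 8-bit grid
print(y4)  # 2.0 is clamped to 15 * 0.1 = 1.5 on the 4-bit grid
```

In a uniform int4 scheme, every layer uses the narrow range, including the one that sees the raw input.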

The problem I'm running into:
In most common cases the input data is 8-bit.
When doing QAT for an int4 model, the first layer's fake_quant squeezes 8-bit data into 4 bits (in other words, it cuts the spread of the data).
In this process we lose too much: the precision drop happens right at the input data.
If we could treat the first layer with an 8-bit qconfig and the other layers with a 4-bit qconfig,
we could keep more of the necessary input information.
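A small sketch of the information loss described above, assuming the input is 8-bit data (for example image pixels) normalized to [0, 1] and that the quantization range is exactly [0, 1]. The helper `fake_quant_unsigned` is hypothetical and only for illustration: 256 input levels survive an 8-bit first-layer fake_quant, but collapse to 16 levels under a 4-bit one.

```python
import torch

# Hypothetical 8-bit input (e.g. image pixels 0..255) normalized to [0, 1].
x = torch.arange(256, dtype=torch.float32) / 255.0


def fake_quant_unsigned(t: torch.Tensor, bits: int) -> torch.Tensor:
    """Fake-quantize t onto an unsigned `bits`-bit grid spanning [0, 1]."""
    qmax = 2 ** bits - 1
    scale = 1.0 / qmax  # the data range is assumed to be exactly [0, 1]
    return torch.fake_quantize_per_tensor_affine(t, scale, 0, 0, qmax)


y8 = fake_quant_unsigned(x, 8)  # first layer kept at 8 bits
y4 = fake_quant_unsigned(x, 4)  # first layer squeezed to 4 bits

err8 = (y8 - x).abs().max().item()
err4 = (y4 - x).abs().max().item()
print(y8.unique().numel(), err8)  # 256 distinct levels, ~zero error
print(y4.unique().numel(), err4)  # 16 distinct levels, error up to 8/255
```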

Is there any documentation on using two or more qconfigs in the same QAT process?

I've noticed the PR "Add MKLDNN quantization backend" by Xia-Weiwen (pytorch/pytorch Pull Request #67177 on GitHub),
but it's a bit different: the MKLDNN backend just changes the internal compute logic,
whereas I want to add a new backend.
I found some reference material here: Extending PyTorch Quantization to Custom Backends (pytorch/pytorch wiki on GitHub).
Are there any suggestions for developing a new backend, especially for QAT?

My goals:

  1. I'm developing a new int4 QAT qconfig (or, you could say, a new int4 backend) for a specific DLA.
    In my opinion, using 4 bits in all layers may cause a precision drop, especially in the first layer.
    I'm trying to handle the first layer so that it keeps as much information as possible and prevents that drop.
  2. I'm also looking for use cases or demos of hybrid quantization schemes, for example using an 8-bit qconfig and an fp16 qconfig in the same QAT process. Is there a user interface for this?
  3. I'm searching for a hybrid-quantization QAT demo; do you have one?
  4. Any suggestions on developing a new int4 QAT qconfig (or a new int4 backend)?

Yeah, I would recommend using FX Graph Mode Quantization for this. We have a post-training quantization tutorial here: (prototype) FX Graph Mode Post Training Static Quantization — PyTorch Tutorials 1.10.0+cu102 documentation (we might add a QAT tutorial later). You can use prepare_qat_fx together with the qconfig_dict API to assign a different qconfig to individual layers.
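A minimal sketch of the per-layer approach, assuming a recent PyTorch (1.13 or later), where `QConfigMapping` replaces the `qconfig_dict` mentioned above and `prepare_qat_fx` also takes `example_inputs`. The helper `make_qconfig` and the model `Net` are hypothetical. Emulating int4 by narrowing the `quant_min`/`quant_max` of 8-bit `FakeQuantize` modules is an assumption: it simulates int4 rounding during training, but the tensors stay `quint8`/`qint8`, so lowering to real int4 kernels is a separate step.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    FakeQuantize, MovingAverageMinMaxObserver, QConfig, QConfigMapping)
from torch.ao.quantization.quantize_fx import prepare_qat_fx


def make_qconfig(act_bits: int, weight_bits: int) -> QConfig:
    """Hypothetical helper: emulate n-bit QAT by narrowing 8-bit fake-quant ranges."""
    act = FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=0, quant_max=2 ** act_bits - 1,
        dtype=torch.quint8, qscheme=torch.per_tensor_affine)
    weight = FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=-(2 ** (weight_bits - 1)), quant_max=2 ** (weight_bits - 1) - 1,
        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
    return QConfig(activation=act, weight=weight)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)  # first layer: sees raw input
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 8, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(8, 4, 3, padding=1)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        return self.conv3(x)


int8_qconfig = make_qconfig(8, 8)
int4_qconfig = make_qconfig(4, 4)

qconfig_mapping = (QConfigMapping()
                   .set_global(int4_qconfig)                 # default: int4
                   .set_module_name("conv1", int8_qconfig))  # first layer: int8

example_inputs = (torch.randn(1, 3, 16, 16),)
prepared = prepare_qat_fx(Net().train(), qconfig_mapping, example_inputs)
out = prepared(*example_inputs)  # one QAT forward pass with mixed fake quants
```

After `prepare_qat_fx`, `conv1` carries an 8-bit weight fake-quant (range -128..127) while `conv2` carries a 4-bit one (-8..7). In FX, the fake-quant on the model input is created from the consuming layer's qconfig, so the input of `conv1` is also observed at 8 bits.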

We currently have a quint4x2 dtype (pytorch/ at master · pytorch/pytorch · GitHub), although I think it is mostly used for weights. To support it for activations, I think you need to:

  1. Add support for quint4x2 in quantize_per_tensor (pytorch/QTensor.cpp at master · pytorch/pytorch · GitHub).
  2. Use the is_reference option during convert_fx, which produces a model whose (dequant - float_op - quant) patterns represent the quantized model (see Extending PyTorch Quantization to Custom Backends in the pytorch/pytorch wiki for the reasoning).
  3. Lower the model to the DLA you are building; this can be done through FX, TorchScript, or any other way you prefer.
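A sketch of step 2, assuming PyTorch 1.13 or later, where `convert_to_reference_fx` is the dedicated entry point for what `convert_fx(..., is_reference=True)` did. As a stand-in, it uses the stock int8 QAT mapping from `get_default_qat_qconfig_mapping("fbgemm")`; int4 activations would first need step 1. The printed graph code shows the quantize/dequantize calls around float ops that a custom lowering pass would pattern-match.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_qat_fx

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 4, 3)).train()
example_inputs = (torch.randn(1, 3, 16, 16),)

# Stock int8 QAT mapping as a stand-in for a custom int4 qconfig.
prepared = prepare_qat_fx(
    model, get_default_qat_qconfig_mapping("fbgemm"), example_inputs)
for _ in range(3):  # placeholder for the real QAT training loop
    prepared(*example_inputs)

prepared.eval()
reference = convert_to_reference_fx(prepared)  # (dequant - float_op - quant) form
print(reference.code)  # input for a custom lowering pass to the DLA
```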

Is there any guide for hybrid quantization in QAT, i.e. mixing int8 and int4 during training?

It would be better to use int8 in the first and last layers and int4 in the inner layers.

Using int8 in the first layer may prevent the source data from being lost.
Using int8 in the last layer may help downstream processing after inference (like video output or another accelerator).
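One way to express "int8 for the first and last layers, int4 inside" with FX graph mode QAT, assuming PyTorch 1.13 or later. `first_last_int8_mapping` is a hypothetical helper, and it assumes that the registration order of `named_modules()` matches the execution order. This holds for `nn.Sequential`, but not necessarily for a custom `forward()`, where naming the layers explicitly with `set_module_name` is safer. The int4 qconfig is emulated with narrowed 8-bit fake quants, which is an assumption, not a real int4 dtype.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    FakeQuantize, MovingAverageMinMaxObserver, QConfig, QConfigMapping,
    get_default_qat_qconfig)
from torch.ao.quantization.quantize_fx import prepare_qat_fx


def first_last_int8_mapping(model: nn.Module, int8_qconfig: QConfig,
                            int4_qconfig: QConfig) -> QConfigMapping:
    """Hypothetical helper: int4 everywhere, int8 for the first and last Conv/Linear."""
    # Assumption: named_modules() registration order matches execution order.
    layers = [name for name, m in model.named_modules()
              if isinstance(m, (nn.Conv2d, nn.Linear))]
    mapping = QConfigMapping().set_global(int4_qconfig)
    for name in dict.fromkeys((layers[0], layers[-1])):  # dedupe single-layer models
        mapping.set_module_name(name, int8_qconfig)
    return mapping


# Illustrative int4 QConfig emulated with narrowed 8-bit fake quants.
int4_qconfig = QConfig(
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=15,
        dtype=torch.quint8),
    weight=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver, quant_min=-8, quant_max=7,
        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
int8_qconfig = get_default_qat_qconfig("fbgemm")

model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(),           # "0": first layer -> int8
    nn.Conv2d(8, 8, 3), nn.ReLU(),           # "2": inner layer -> int4
    nn.Flatten(), nn.Linear(8 * 4 * 4, 10))  # "5": last layer -> int8

mapping = first_last_int8_mapping(model, int8_qconfig, int4_qconfig)
example_inputs = (torch.randn(1, 3, 8, 8),)
prepared = prepare_qat_fx(model.train(), mapping, example_inputs)
out = prepared(*example_inputs)
```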