Starting out with QAT

Hi all, I’m fairly new to model optimization and have only tried ONNX PTQ methods so far. I now need to explore QAT for YOLO PyTorch models and I’m not sure where to start.

Should I use Eager Mode or FX Graph Mode Quantization?
Which of them is easier, and which generalizes better across different models?

Thanks in advance!

What is the difference between using torch.quantization.quantize_qat and following the available tutorials, i.e. fusing modules, adding quant and dequant, etc.?

Hi @MrOCW, eager mode quantization is manual, in that you have to change the modeling code to add quants/dequants and specify fusions. FX graph mode quantization is automatic, but it requires the model to be symbolically traceable.

Usually for new models I’d recommend trying FX graph mode quantization first. If the model is not symbolically traceable, then you would have to either make it symbolically traceable (difficulty depends on the model), or use eager mode.
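A quick way to find out which case a model falls into is to attempt a trace. The sketch below is only illustrative: the two toy modules and the helper name is_symbolically_traceable are made up for this example, and a real YOLO model may fail for reasons other than the data-dependent branch shown here.

```python
import torch
import torch.nn as nn
from torch import fx


class Traceable(nn.Module):
    """Toy module with a fixed dataflow, so torch.fx can trace it."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return torch.relu(self.conv(x))


class NotTraceable(nn.Module):
    """Toy module that branches on a tensor value (data-dependent control flow)."""

    def forward(self, x):
        if x.sum() > 0:  # symbolic tracing cannot evaluate this condition
            return x
        return -x


def is_symbolically_traceable(model: nn.Module) -> bool:
    """Hypothetical helper: return True if torch.fx.symbolic_trace succeeds."""
    try:
        fx.symbolic_trace(model)
        return True
    except Exception:
        return False


print(is_symbolically_traceable(Traceable()))
print(is_symbolically_traceable(NotTraceable()))
```

If the check fails, the error message from fx.symbolic_trace usually points at the line that needs rewriting before FX graph mode can be used.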

I remember looking at yolov3 a few months ago and that model was challenging to make symbolically traceable.

Guess I’ll work with eager mode for now since it seems to be more general, just with more manual work. However, I am not sure where to add the Quant and DeQuant modules. I am currently working with a model that has Conv2d-BatchNorm2d-SiLU blocks, and since SiLU is not supported, do I fuse conv+bn and then add a quant and dequant around every instance of that block? For example: Quant → ConvBn → DeQuant → SiLU ===> Quant → ConvBn → DeQuant → SiLU ===> …

Yes, quant/dequant control which areas of the model run in which dtype (torch.float vs torch.quint8).

Quant → ConvBn → DeQuant → SiLU ===> Quant → ConvBn → DeQuant → SiLU

yep, that sounds right. There is some example code here (Quantization — PyTorch 1.9.1 documentation) with similar toy examples.
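For reference, a minimal eager-mode sketch of the Quant → ConvBn → DeQuant → SiLU pattern could look like the following. The ConvBnSiLU class, the channel sizes and the "fbgemm" qconfig are assumptions made for illustration, not taken from YOLOX, and fuse_modules_qat is assumed to be available (it exists in recent PyTorch releases; older releases used fuse_modules on a model in training mode).

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq


class ConvBnSiLU(nn.Module):
    """Hypothetical block: conv+bn run quantized, SiLU stays in float."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.quant = tq.QuantStub()      # float -> quantized
        self.conv = nn.Conv2d(c_in, c_out, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.dequant = tq.DeQuantStub()  # quantized -> float
        self.act = nn.SiLU()             # applied after dequant, so it sees float

    def forward(self, x):
        x = self.quant(x)
        x = self.bn(self.conv(x))
        x = self.dequant(x)
        return self.act(x)


model = nn.Sequential(ConvBnSiLU(3, 8), ConvBnSiLU(8, 8))
model.train()  # prepare_qat expects a model in training mode

# Fuse conv+bn inside each block; names refer to attributes of the block.
for block in model:
    tq.fuse_modules_qat(block, [["conv", "bn"]], inplace=True)

model.qconfig = tq.get_default_qat_qconfig("fbgemm")
prepared = tq.prepare_qat(model)

# The prepared model is still a float model with fake-quant modules inserted.
out = prepared(torch.randn(2, 3, 16, 16))
print(out.shape)
```

Each block pays for an extra quantize/dequantize pair, which is the cost of keeping the unsupported activation in float.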

In my case, I just used Quant → model → Dequant with the SiLUs inside the model, and it still trains but with terrible accuracy. Is including SiLU between the Quant and Dequant the main reason?

QAT accuracy can depend on many things, including when you turn on fake_quant and how many batches you train the model for. Are you training from scratch or from pretrained weights?

Also could you paste the modified model here?

@jerryzh168
Here’s a YOLOX_small model prepared for QAT.
The YOLOX small pretrained weights were loaded first, then I applied fuse_modules with ["conv", "bn"] and then prepare_qat.

QAT model
https://drive.google.com/file/d/1rUtNvHKzkR5mum3n9KHSwe11Ym-oQNR6/view?usp=sharing

Converted model
https://drive.google.com/file/d/1q_OOxc3jmdowLXXU9P7bANYHRbzVYJAi/view?usp=sharing

The structure of the prepared and converted models looks OK, I think. How do you do QAT? Are you turning on fake_quant at the very beginning? It’s typically suggested to turn on fake_quant only after a few batches, so that the observers are populated with correct stats first.

Sorry, how do I control that?
Currently, I am doing

Quant → fused model → Dequant

and then just running the training loop without any edits.

Which way is recommended?
1. Train without QAT, load the trained weights, fuse and add quant/dequant, then repeat training
2. Start QAT on my custom data right from the official pretrained weights

What are some hyperparameters I should take note of when performing QAT? (eg. epochs, learning rate, etc)

OK, I tried training without QAT first, then loading those trained weights and training with QAT for 1 epoch, and the mAP is close to the non-QAT model. Is this workflow the “correct” way?
Am I not supposed to train with QAT from scratch? Doing so doesn’t increase my mAP from 0 at all.

The flow should be:

1. prepared = prepare_qat(model, ...)
2. disable fake_quant but keep observation enabled:
prepared.apply(torch.ao.quantization.disable_fake_quant)
prepared.apply(torch.ao.quantization.enable_observer)
3. train for a few epochs
4. enable fake_quant and do QAT:
prepared.apply(torch.ao.quantization.enable_fake_quant)
5. train for a few more epochs
6. convert to a quantized model

How many epochs to train before turning on fake_quant, and how many after, are hyperparameters you can experiment with.
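Put together, the staged flow can be sketched roughly as below. Everything model-specific here is a stand-in: TinyNet, the train_steps helper, the random data, the step counts and the learning rate are all made up for illustration, and a real run would use the actual training loop instead. Note that step 4 calls enable_fake_quant, and the backend is picked from whatever quantized engines the local build supports.

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq


class TinyNet(nn.Module):
    """Stand-in model; a real YOLO model would replace this."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.conv = nn.Conv2d(3, 4, 3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))


def train_steps(model, steps):
    """Hypothetical training loop on random data with a dummy loss."""
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    for _ in range(steps):
        loss = model(torch.randn(4, 3, 8, 8)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()


# Pick a quantized backend that this PyTorch build supports.
engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend

# 1. prepare
model = TinyNet().train()
model.qconfig = tq.get_default_qat_qconfig(backend)
prepared = tq.prepare_qat(model)

# 2-3. observers on, fake_quant off: collect statistics while training
prepared.apply(tq.disable_fake_quant)
prepared.apply(tq.enable_observer)
train_steps(prepared, 5)

# 4-5. turn fake_quant on and continue training (the actual QAT phase)
prepared.apply(tq.enable_fake_quant)
train_steps(prepared, 5)

# 6. convert to a quantized model for inference
prepared.eval()
quantized = tq.convert(prepared)
print(quantized(torch.randn(1, 3, 8, 8)).shape)
```

The two step counts passed to train_steps correspond to the two hyperparameters mentioned above: how long to train before and after fake_quant is turned on.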