Quantization-aware training: simulating bit widths below 8

Hello, I am trying to simulate quantization-aware training with a custom bit width. I've noticed that, depending on the model I use, I sometimes have difficulty getting the model to converge at certain bit widths.

For ResNet-18, the model converges at 8, 7, 6 and 5 bits. Once I go down to 4 bits, the error stays roughly the same, even after more than 100 epochs. Does anyone have insights on this, so I know how to tackle the issue?

Thank you for your ideas.

Maybe the bit width is just too low for convergence? You could try other quantization schemes, e.g. per-channel quantization, so that more information is preserved.
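To illustrate why per-channel quantization can preserve more information, here is a minimal plain-Python sketch (not PyTorch code) of symmetric 4-bit weight quantization. It assumes a toy, hypothetical weight matrix where each row stands for one output channel and the channels have very different magnitudes. With a single per-tensor scale, every value in the small channel rounds to zero; with one scale per channel, those values survive.

```python
# Minimal sketch comparing per-tensor and per-channel symmetric weight
# quantization in plain Python. The toy weight matrix is hypothetical:
# each row plays the role of one output channel.

def qrange(bits):
    """Signed range used in this thread: [-(2**b)//2, (2**b)//2 - 1]."""
    return -(2 ** bits) // 2, (2 ** bits) // 2 - 1


def symmetric_scale(values, bits):
    """Scale so that the largest |value| maps to the edge of the range."""
    qmin, qmax = qrange(bits)
    return max(abs(v) for v in values) / ((qmax - qmin) / 2)


def fake_quant_symmetric(values, scale, bits):
    """Quantize then dequantize with zero_point = 0, clamping to the range."""
    qmin, qmax = qrange(bits)
    return [max(qmin, min(qmax, round(v / scale))) * scale for v in values]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def compare(weights, bits):
    flat = [w for row in weights for w in row]
    # Per-tensor: one scale shared by every channel.
    s = symmetric_scale(flat, bits)
    per_tensor = [fake_quant_symmetric(row, s, bits) for row in weights]
    # Per-channel: one scale per row (output channel).
    per_channel = [fake_quant_symmetric(row, symmetric_scale(row, bits), bits)
                   for row in weights]
    err_t = mse(flat, [w for row in per_tensor for w in row])
    err_c = mse(flat, [w for row in per_channel for w in row])
    return err_t, err_c


if __name__ == "__main__":
    # Hypothetical weights: channel 0 has small magnitudes, channel 1 large ones.
    weights = [
        [0.010, -0.020, 0.015, -0.005],
        [1.000, -0.800, 0.600, -0.900],
    ]
    err_t, err_c = compare(weights, bits=4)
    print(f"per-tensor MSE:  {err_t:.6f}")
    print(f"per-channel MSE: {err_c:.6f}")
```

The large channel gets the same scale either way, so the whole difference comes from the small channel, which per-tensor quantization wipes out.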

Yes, I ran some tests, and it seems that when I use a low bit width for the activations, the model won't converge. For example, with 2 bits for the weights and 8 bits for the activations it works fine, but once I go down to 4 bits for the activations, my model no longer converges. I wonder what the relationship between the two is; maybe I can find a way to reduce both as much as possible, since later on it will also be useful for my problem to quantize the activations aggressively.
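One way to see how the activation bit width relates to the error is to look at the quantization step. With an unsigned affine range of [0, 2^b - 1], each bit removed halves the number of levels and roughly doubles the step Δ, and for rounding noise spread uniformly within a step the mean squared error is about Δ²/12. The plain-Python sketch below prints both numbers for several bit widths. It assumes made-up, ReLU-like activations and a simplified min/max affine calibration; it only shows the loss of resolution and does not by itself explain why training stalls at 4 bits.

```python
# Minimal sketch: how the step size and rounding error of unsigned affine
# fake quantization grow as the activation bit width shrinks. The sample
# activations are made up (deterministic, non-negative, ReLU-like).

def affine_qparams(min_val, max_val, bits):
    """Scale/zero-point for an unsigned range [0, 2**bits - 1] (simplified)."""
    qmin, qmax = 0, 2 ** bits - 1
    # The represented range must include 0.
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - round(min_val / scale)
    return scale, max(qmin, min(qmax, zero_point)), qmin, qmax


def fake_quant_affine(x, scale, zero_point, qmin, qmax):
    """Quantize to an integer level, clamp, then map back to float."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return (q - zero_point) * scale


def activation_error(acts, bits):
    scale, zp, qmin, qmax = affine_qparams(min(acts), max(acts), bits)
    deq = [fake_quant_affine(a, scale, zp, qmin, qmax) for a in acts]
    mse = sum((a - d) ** 2 for a, d in zip(acts, deq)) / len(acts)
    return scale, mse


if __name__ == "__main__":
    acts = [(i * 37 % 101) / 25.0 for i in range(1000)]  # values in [0, 4.0]
    for bits in (8, 6, 4, 2):
        scale, mse = activation_error(acts, bits)
        print(f"{bits} bits: step={scale:.4f}  mse={mse:.6f}  "
              f"(step^2/12={scale ** 2 / 12:.6f})")
```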

As for per-channel quantization, I am already doing that.

I think it is common to use more sophisticated techniques to compensate for the accuracy loss at lower bit widths, for example additive powers-of-two (APoT) quantization: pytorch/torch/ao/quantization/experimental at master · pytorch/pytorch · GitHub, and the paper [1909.13144] Additive Powers-of-Two Quantization: An Efficient Non-uniform Discretization for Neural Networks
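For intuition on what APoT changes, here is a small plain-Python sketch of the non-negative level set as I understand it from arXiv:1909.13144. With b = k·n bits, each level is a sum of n terms, where term i is either 0 or one of 2^-i, 2^-(i+n), ..., 2^-(i+(2^k-2)n), and the result is rescaled so the largest level equals the clipping value α. Compared with 16 uniform 4-bit levels, the APoT levels are packed more densely near zero, where (as the paper argues) most weights lie. This is an illustration under that reading of the paper, not the implementation in torch/ao/quantization/experimental.

```python
# Minimal sketch of the Additive Powers-of-Two (APoT) level set, following
# my reading of arXiv:1909.13144 (unsigned case, b = k * n bits).
# Illustration only; not the code in torch/ao/quantization/experimental.
from itertools import product


def apot_levels(alpha, k, n):
    """Return sorted non-negative APoT levels, rescaled so max == alpha."""
    # Term i draws from {0, 2^-i, 2^-(i+n), ..., 2^-(i + (2^k - 2) n)}.
    term_sets = [
        [0.0] + [2.0 ** -(i + j * n) for j in range(2 ** k - 1)]
        for i in range(n)
    ]
    sums = sorted({sum(combo) for combo in product(*term_sets)})
    gamma = alpha / sums[-1]  # rescale so the largest level equals alpha
    return [gamma * s for s in sums]


def uniform_levels(alpha, bits):
    """Uniform levels 0, alpha/(2^b - 1), ..., alpha for comparison."""
    steps = 2 ** bits - 1
    return [alpha * i / steps for i in range(steps + 1)]


def project(x, levels):
    """Quantize x to the nearest level (implicitly clips to [0, alpha])."""
    return min(levels, key=lambda lv: abs(lv - x))


if __name__ == "__main__":
    apot = apot_levels(alpha=1.0, k=2, n=2)  # 4-bit APoT
    uni = uniform_levels(alpha=1.0, bits=4)  # 4-bit uniform
    print("APoT   :", [round(v, 4) for v in apot])
    print("uniform:", [round(v, 4) for v in uni])
    print(project(0.05, apot), project(0.05, uni))
```

Both schemes use 16 levels for 4 bits; they differ only in where those levels sit.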


Can you share the code showing how you simulate a bit width lower than 8?

Assuming you know how to do normal QAT with PyTorch, the main difference is in your qconfig. You need something like this:

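import torch

activation_bitwidth = 8  # whatever bit width you want for activations
bitwidth = 4  # whatever bit width you want for weights

# Activations: per-tensor affine, unsigned range [0, 2**b - 1] (quint8 container)
fq_activation = torch.quantization.FakeQuantize.with_args(
    observer=torch.quantization.MovingAverageMinMaxObserver.with_args(
        quant_min=0,
        quant_max=2 ** activation_bitwidth - 1,
        dtype=torch.quint8, qscheme=torch.per_tensor_affine))

# Weights: per-channel symmetric, signed range [-(2**b)/2, (2**b)/2 - 1] (qint8 container)
fq_weights = torch.quantization.FakeQuantize.with_args(
    observer=torch.quantization.MovingAveragePerChannelMinMaxObserver.with_args(
        quant_min=-(2 ** bitwidth) // 2,
        quant_max=(2 ** bitwidth) // 2 - 1,
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric))

intB_qat_qconfig = torch.quantization.QConfig(activation=fq_activation, weight=fq_weights)
model_ft.qconfig = intB_qat_qconfig
# afterwards: torch.quantization.prepare_qat(model_ft, inplace=True), then train as usual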

This way you only simulate the reduced ranges: you can't use the .convert method afterwards, so it is purely a simulation to see how your model reacts to this quantization. If you need any help, let me know.
More info on using fake quantization here: FakeQuant
N.B.: you can choose whatever configuration fits your needs; the example above is just what I used in my case.