Quantizing Transformer Architecture Below 8-bit (post-training quantization)

I’m trying to quantize BERT to 4 bits or mixed precision, and I don’t see any available methods to do quantization-aware training on BERT for any precision other than torch.uint8, which is what the dynamic quantization tutorial shows.
I want to use both post-training quantization and dynamic quantization at lower than 8 bits.

Will I have to rewrite the modeling_bert.py (transformers/modeling_bert.py) layers with fake quantization added? How can lower-than-8-bit precision and mixed precision be implemented on BERT?
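
For reference, this is the kind of thing I have in mind — a minimal sketch of my own (the fake_quantize helper below is hypothetical, not part of transformers or torch.quantization) that simulates k-bit quantization by quantize-dequantizing the FP32 weights of each Linear layer in place:

```python
import torch

def fake_quantize(t, num_bits=4):
    # Simulate k-bit affine quantization in float: quantize, then
    # immediately dequantize, so the tensor itself stays FP32.
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (t.max() - t.min()).clamp(min=1e-8) / (qmax - qmin)
    zero_point = qmin - torch.round(t.min() / scale)
    q = torch.clamp(torch.round(t / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

# e.g. applied to every Linear weight in a loaded BERT model:
# with torch.no_grad():
#     for m in model.modules():
#         if isinstance(m, torch.nn.Linear):
#             m.weight.copy_(fake_quantize(m.weight, num_bits=4))
```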

The difficulty there is that PyTorch inherently assumes tensors take at least 1 byte per element when it does anything with memory.
I’d probably convert to TVM and see what can be done there.
(QAT with fake quantization could probably work for 4 bits, too.)
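
Roughly, 4-bit QAT with fake quantization could look like the sketch below — this is not an existing torch.quantization API, just a hypothetical QuantLinear you’d swap in for BERT’s nn.Linear layers, with a straight-through estimator passing gradients through the rounding step:

```python
import torch

class FakeQuantSTE(torch.autograd.Function):
    # Round to 2**num_bits levels in forward; pass gradients straight
    # through the (non-differentiable) rounding in backward.
    @staticmethod
    def forward(ctx, x, num_bits):
        qmin, qmax = 0, 2 ** num_bits - 1
        scale = (x.max() - x.min()).clamp(min=1e-8) / (qmax - qmin)
        zero_point = qmin - torch.round(x.min() / scale)
        q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
        return (q - zero_point) * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

class QuantLinear(torch.nn.Linear):
    # nn.Linear that fake-quantizes its weight to 4 bits on every forward,
    # so fine-tuning sees the quantization error while weights stay FP32.
    def forward(self, x):
        w = FakeQuantSTE.apply(self.weight, 4)
        return torch.nn.functional.linear(x, w, self.bias)
```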

It’s not an issue even if the weights are stored as FP32 values in memory.
I’m trying to evaluate post-training quantization, or fine-tune the model with quantization-aware training, but do all of this under fake quantization at any bit width of my choosing.

While I don’t think it works out of the box, you could try to adapt the observers and fake quant layers to be more flexible. For example, there are some obvious 8-bit hard-coded values here:

We do have support for lower bits in https://github.com/pytorch/pytorch/blob/master/torch/quantization/observer.py#L185 now; one of our interns added this recently.
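
With that change in place, the observers accept quant_min/quant_max directly, so a 4-bit range can be configured roughly like this (check the exact signature in the version you have installed, since the kwargs are recent):

```python
import torch
from torch.quantization import MinMaxObserver

# 4-bit unsigned range (16 levels) instead of the default 0..255
obs = MinMaxObserver(dtype=torch.quint8,
                     qscheme=torch.per_tensor_affine,
                     quant_min=0, quant_max=15)

obs(torch.randn(16, 32))                      # collect min/max statistics
scale, zero_point = obs.calculate_qparams()   # qparams for the 16-level range
print(scale, zero_point)
```

FakeQuantize also takes quant_min/quant_max in its constructor, so the same range can be used for QAT, though how those values are threaded through to the underlying observer depends on the version.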
