Variable-bit (sub-8-bit) quantization for custom hardware deployment with power-of-two (PoT) scales

Hey everyone! I am looking for a way to perform Quantization-Aware Training (QAT) using PyTorch.

My use case concerns deploying trained PyTorch models on custom hardware (silicon), so I have a few requirements:

  • Needs to support nn.Conv1d (as this is part of the network that I want to deploy)
  • Needs to support some form of batch-norm folding
  • Needs to have power-of-two scales (as this avoids integer divisions in hardware)
  • Preferably does not require me to redefine my model with quantized modules
  • Supports up-to-date Python and PyTorch versions

In my search, I checked out every quantization-aware-training framework I could find and made a list of them (see below); parentheses () give the last commit date, and square brackets [] indicate whether I tried it.

However, none of these options actually works or has all the features I need. Does anyone have suggestions on what I could use or do?


Hi @d0uwe,

We have a tutorial for QAT here that you might find helpful.

Thanks for the quick response! I have read this tutorial, but it is unclear to me how I can quantize the weights to fewer than 8 bits (for example, 5).

As far as I know, PyTorch 2.0 does not natively support weights quantized to fewer than 8 bits. But you can emulate it numerically with a customized observer.

For example, if you want to quantize weight to int4, you can try the following setting:

import torch
from torch.ao.quantization.observer import MinMaxObserver

# Emulate int4 by restricting the signed int8 range to [-8, 7]
custom_observer = MinMaxObserver(dtype=torch.qint8, quant_min=-8, quant_max=7)
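
To actually train with this, you can wrap the observer in a FakeQuantize and plug it into a QConfig. A minimal sketch for the eager-mode QAT workflow (the int4_weight_fake_quant name is just mine):

import torch
from torch.ao.quantization import QConfig, FakeQuantize, default_fake_quant
from torch.ao.quantization.observer import MinMaxObserver

# Weights are fake-quantized to the emulated int4 range; activations
# keep the default 8-bit fake quantization.
int4_weight_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver,
    quant_min=-8,
    quant_max=7,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
)
qconfig = QConfig(activation=default_fake_quant, weight=int4_weight_fake_quant)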

Thanks, that looks good! How would I then implement, for example, that the scales of the quantizer are always a power of two? I.e., a scale factor of 1.9 is rounded to 2 (2^1), a scale factor of 59 is rounded to 64 (2^6), etc.

You could create your own observer class and override this function (pytorch/observer.py at 5cc2e4d7c939852f6de6f8497dc89d311e333dce · pytorch/pytorch · GitHub) to calculate the scales with a power-of-two restriction.
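
For example, a minimal sketch (PowerOfTwoMinMaxObserver is a name I made up, not a PyTorch API) that subclasses MinMaxObserver and rounds the computed scale to the nearest power of two in log space:

import torch
from torch.ao.quantization.observer import MinMaxObserver

class PowerOfTwoMinMaxObserver(MinMaxObserver):
    # Tracks min/max as usual, but snaps the resulting scale to the
    # closest power of two before returning it.
    @torch.jit.export
    def calculate_qparams(self):
        scale, zero_point = super().calculate_qparams()
        scale = torch.pow(2.0, torch.round(torch.log2(scale)))
        return scale, zero_point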

So if I understand correctly, the scale and zero-point are currently not learned but calculated, right?

For creating my own observer class, from which class would I inherit? From UniformQuantizationObserverBase (or any other observer) or from ObserverBase?

Then, for a learned scale and zero-point, is the approach below correct?

import torch
import torch.nn as nn

def quantize_to_closest_power_of_two(x: torch.Tensor):
    # Snap each element to the nearest power of two in log space,
    # e.g. 1.9 -> 2.0 and 59.0 -> 64.0.
    return torch.pow(2.0, torch.round(torch.log2(x)))

class MyObserverClass(...):  # which base class should this be?
    def __init__(self, ...):
        ...

        self.scale = nn.Parameter(torch.ones(required_shape))
        self.zero_point = nn.Parameter(torch.zeros(required_shape))  # zero-point starts at 0

    @torch.jit.export
    def calculate_qparams(self):
        return quantize_to_closest_power_of_two(self.scale), self.zero_point
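
As a quick sanity check of the rounding helper, matching the examples above:

>>> quantize_to_closest_power_of_two(torch.tensor([1.9, 59.0]))
tensor([ 2., 64.])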

Then a related question: what if I also want the weights themselves to be powers of two, in addition to the scales? Do I create a new class with FakeQuantize as a base and then, in the forward pass, clamp the weights to the closest powers of two?

Hi @d0uwe @Vasiliy_Kuznetsov - did you ever figure this out? Just running into the same issue at the moment.

I think you can create a custom Observer/FakeQuantize class that makes sure the weights are snapped to powers of two.
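
A minimal sketch of that idea (the class name is mine; it assumes the standard FakeQuantize flow plus a straight-through estimator for the extra rounding step):

import torch
from torch.ao.quantization import FakeQuantize

class PowerOfTwoWeightFakeQuantize(FakeQuantize):
    # Runs the usual fake quantization, then snaps the result to signed
    # powers of two; gradients pass straight through the snapping.
    def forward(self, x):
        x = super().forward(x)
        sign = torch.sign(x)
        magnitude = torch.clamp(x.abs(), min=1e-8)  # avoid log2(0)
        pot = sign * torch.pow(2.0, torch.round(torch.log2(magnitude)))
        return x + (pot - x).detach()  # straight-through estimator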

Hi Tobias, no, I didn’t figure out how to properly use the PyTorch APIs for this. I finally decided to use Brevitas because it offers a lot of flexibility.

However, since PoT weight quantization was not supported, I built it myself here as part of a mini-library that makes Brevitas 10x easier to use.