A more or less ‘baked in’ version of this is quantization-aware training (QAT), where the quantization library simulates the quantized operator using fake quant modules. It is generally used for training, but it should work for your purposes as well.
See the Quantization Aware Training section here: Quantization — PyTorch master documentation
You could then adapt it to your altered bit widths, i.e. the tutorial recipe altered for int4 would be:
import torch
from torch.quantization import QConfig, FakeQuantize, HistogramObserver

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

model_fp32 = M()
model_fp32.train()

## int8 qconfig:
# model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
######### note the above qconfig is equivalent to: ######################
## act_fq = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
## ...
## weight_fq = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
## model_fp32.qconfig = QConfig(activation=act_fq, weight=weight_fq)

# B is the number of quantization bits
B = 4

## intB qconfig:
intB_act_fq = FakeQuantize.with_args(
    observer=HistogramObserver, quant_min=0, quant_max=2**B - 1,
    dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False)
intB_weight_fq = FakeQuantize.with_args(
    observer=HistogramObserver, quant_min=-(2**B) // 2, quant_max=(2**B) // 2 - 1,
    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
intB_qconfig = QConfig(activation=intB_act_fq, weight=intB_weight_fq)
model_fp32.qconfig = intB_qconfig

model_fp32_fused = torch.quantization.fuse_modules(model_fp32,
                                                   [['conv', 'bn', 'relu']])
model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused)

# calibrate the model: observers collect statistics, fake quant stays off
model_fp32_prepared.apply(torch.ao.quantization.disable_fake_quant)
calibration_code()

# turn fake quant on and freeze the observers so the test data
# doesn't change the quantization parameters
model_fp32_prepared.apply(torch.ao.quantization.enable_fake_quant)
model_fp32_prepared.apply(torch.ao.quantization.disable_observer)

# test the intB numerics
test_code()
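As a sanity check after calibration you can print the ranges and scales each fake quant has settled on. This is just a minimal sketch, assuming the standard FakeQuantize modules inserted by prepare_qat (calibration_code() and test_code() above are placeholders for your own loops):

# sketch: inspect the quantization parameters each FakeQuantize computed
for name, module in model_fp32_prepared.named_modules():
    if isinstance(module, FakeQuantize):
        scale, zero_point = module.calculate_qparams()
        print(name, module.quant_min, module.quant_max, scale, zero_point)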
Some additional thoughts:
If you are trying to mimic hardware (which is what the flow is generally geared towards), you need reduce_range=True on the activation side for certain backends in order to deal with overflow issues in hardware; it simply removes one bit from the range, so the activations are effectively ~quint7 while the weights stay qint8. If you are just simulating the numerics, that isn’t needed, which is why the example above sets reduce_range=False.
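To make the reduce_range effect concrete, here is a tiny standalone illustration (not part of the recipe above, just a minimal example using MinMaxObserver): halving the quantized range roughly doubles the scale for the same observed min/max.

import torch
from torch.quantization import MinMaxObserver

x = torch.randn(1000)
obs_full = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False)
obs_reduced = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
obs_full(x)
obs_reduced(x)
print(obs_full.calculate_qparams())     # scale for the full 0..255 range
print(obs_reduced.calculate_qparams())  # ~2x larger scale: only 0..127 is used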
In QAT, MovingAverageMinMaxObserver is used to gradually adjust the range as the weights change, which is probably not something you’re interested in. I replaced it with HistogramObserver, which is what you’d use in normal (post-training) quantization and is generally more accurate but slower; if that’s a problem you could replace it with MinMaxObserver.
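The swap is just the observer argument; for example, a faster variant of the int4 qconfig above (the fast_* names here are just ones I made up for this sketch):

from torch.quantization import MinMaxObserver

fast_act_fq = FakeQuantize.with_args(
    observer=MinMaxObserver, quant_min=0, quant_max=2**B - 1,
    dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False)
fast_weight_fq = FakeQuantize.with_args(
    observer=MinMaxObserver, quant_min=-(2**B) // 2, quant_max=(2**B) // 2 - 1,
    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
model_fp32.qconfig = QConfig(activation=fast_act_fq, weight=fast_weight_fq)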
The ‘disable_fake_quant’ around the calibration code is there because normal quantization calibration doesn’t actually simulate the quantized numerics with fake_quant during calibration; the observers just collect statistics while the model runs in floating point.
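If you want to double-check which stage you’re in, the FakeQuantize modules carry observer_enabled / fake_quant_enabled buffers you can inspect (a rough sketch; these buffer names come from the torch.ao.quantization implementation, so treat them as an assumption for your version):

for name, module in model_fp32_prepared.named_modules():
    if isinstance(module, FakeQuantize):
        print(name,
              'observer on:', bool(module.observer_enabled),
              'fake_quant on:', bool(module.fake_quant_enabled))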
Let me know if you have any questions.