Post-Training Quantization to Custom Bitwidth

hi @andrea.migliorati , could you also post how you are setting QConfig? I don’t see that part in your example.

To clarify, here is the part of @HDCharles 's snippet which deals with emulating a bitwidth:

# B is the number of bits to emulate
B = 4

## intB qconfig: quant_min/quant_max narrow the 8-bit dtypes to a B-bit range
intB_act_fq = FakeQuantize.with_args(observer=HistogramObserver, quant_min=0, quant_max=2 ** B - 1, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False)

# integer division keeps quant_min/quant_max as ints rather than floats
intB_weight_fq = FakeQuantize.with_args(observer=HistogramObserver, quant_min=-(2 ** B) // 2, quant_max=(2 ** B) // 2 - 1, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)

intB_qconfig = QConfig(activation=intB_act_fq, weight=intB_weight_fq)

model_fp32.qconfig = intB_qconfig

You would need to apply this config to your model, and change the B parameter accordingly to test different bitwidths.
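As a quick sanity check on the ranges involved, here is a minimal pure-Python sketch of how B maps to the quant_min/quant_max pairs used above. The helper name int_b_ranges is made up for illustration and is not part of PyTorch.

```python
def int_b_ranges(bits):
    """Return ((act_min, act_max), (weight_min, weight_max)) for a B-bit range.

    Hypothetical helper, not a PyTorch API. Activations use an unsigned
    range (as with quint8), weights a signed range (as with qint8).
    """
    if not 2 <= bits <= 8:
        raise ValueError("bits must be between 2 and 8 to fit inside an 8-bit dtype")
    act_range = (0, 2 ** bits - 1)
    weight_range = (-(2 ** bits) // 2, (2 ** bits) // 2 - 1)
    return act_range, weight_range


for bits in (8, 4, 2):
    print(bits, int_b_ranges(bits))
```

Note that `-(2 ** B) / 2` produces a float in Python 3, while `-(2 ** B) // 2` stays an int, which is why the integer-division form is the safer one to pass as quant_min/quant_max.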

Yes, this is exactly the FakeQuantize configuration I am using, but I always get validation accuracy equal to the bitwidth B (8% for 8 bits, 4% for 4 bits, and so on)

it would be difficult to help without seeing the code of how specifically you are setting qconfig and calling the quantization APIs

I am literally using the same code in the HDCharles reply. The only difference is the network I wanna do tests with which is a resnet20. I changed the resnet20 code as explained in my previous reply where it’s also reported how I fuse the model. I don’t understand if the error is in the fusion or in the resnet20 changes I made.

the quant and dequant stubs define where the dtype changes from fp32 to int8 so you can have sections that you want quantized and those you don’t. However you also need to set the qconfig appropriately, any modules with the qconfig set will be quantized (and/or their children).

i.e. in my example I have the quant and dequant surround the entire model since I apply the qconfig to the top level of the model (which during the quantization flow will apply that same qconfig to the rest of the model)

if you have parts of the model that you don’t want quantized, you need to make sure those don’t have a qconfig and that those were surrounded by a dequant and quant.
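A small sketch of the qconfig propagation described above, assuming a toy two-layer module (TwoPart and its attribute names are invented for illustration). Setting qconfig on the top level and then setting a child's qconfig to None should leave that child untouched by prepare_qat. The quant/dequant stubs are left out here to keep the example short; a real model would still need them at the boundaries.

```python
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat


class TwoPart(nn.Module):
    """Toy model: one conv we want fake-quantized, one we want left in fp32."""

    def __init__(self):
        super().__init__()
        self.quantized_part = nn.Conv2d(1, 1, kernel_size=3, padding=1)
        self.float_part = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.float_part(self.quantized_part(x))


model = TwoPart().train()  # prepare_qat expects a model in training mode
model.qconfig = get_default_qat_qconfig("fbgemm")  # propagates to children
model.float_part.qconfig = None  # opt this child out of quantization

prepared = prepare_qat(model)

# QAT modules carry a weight_fake_quant attribute; untouched modules do not.
has_fake_quant = {
    name: hasattr(getattr(prepared, name), "weight_fake_quant")
    for name in ("quantized_part", "float_part")
}
print(has_fake_quant)
```

Printing the prepared model is a quick way to see which modules actually received fake-quant modules and which did not.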

if you can make a toy repro that shows your issue, it would be helpful, Vasiliy was saying that its unclear which modules in your example have qconfig set and which do not.

as a clear example, here is how i would handle the BasicBlocks where i always avoid quantizing shortcut:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import FakeQuantize, HistogramObserver, MinMaxObserver, QConfig


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option="A"):
        super(BasicBlock, self).__init__()
        # self.quant = torch.ao.quantization.QuantStub()  # I would put these in top level
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.q = torch.ao.quantization.QuantStub()
        self.dq = torch.ao.quantization.DeQuantStub()
        # eager mode quantization works poorly with functionals, need a module in order to do fusion
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == "A":
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0)
                )
            elif option == "B":
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes),
                )
        # self.dequant = torch.ao.quantization.DeQuantStub()  # I would put these in top level

    def forward(self, x):
        # out = self.quant(x)  # I would put these in top level
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.dq(out)
        out += self.shortcut(x)  # surround shortcut with dq, q
        out = self.q(out)
        out = self.relu2(out)
        # out = self.dequant(out)  # I would put these in top level
        return out


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.bb1 = BasicBlock(1, 1, 1)
        self.bb2 = BasicBlock(1, 1, 1)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.bb1(x)
        x = self.bb2(x)
        x = self.dequant(x)
        return x


model_fp32 = Net().eval()

# B is bits
B = 4

##intB qconfig:
intB_act_fq = FakeQuantize.with_args(
    observer=HistogramObserver,
    quant_min=0,
    quant_max=2 ** B - 1,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False,
)

intB_weight_fq = FakeQuantize.with_args(
    observer=MinMaxObserver,
    quant_min=-(2 ** B) // 2,
    quant_max=(2 ** B) // 2 - 1,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False,
)


intB_qconfig = QConfig(activation=intB_act_fq, weight=intB_weight_fq)

# not sure what lambda layer is but i doubt we can quantize it
for subnet in [model_fp32.bb1, model_fp32.bb2]:
    for name, module in subnet.named_children():
        if "shortcut" not in name:
            module.qconfig = intB_qconfig

model_fp32.quant.qconfig = intB_qconfig
model_fp32.dequant.qconfig = intB_qconfig

to_fuse = [
    ["bb1.conv1", "bb1.bn1", "bb1.relu1"],
    ["bb1.conv2", "bb1.bn2"],
    ["bb2.conv1", "bb2.bn1", "bb2.relu1"],
    ["bb2.conv2", "bb2.bn2"],
]

model_fp32_fused = torch.ao.quantization.fuse_modules_qat(model_fp32, to_fuse).train()

model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused)

# calibrate model
model_fp32_prepared.apply(torch.ao.quantization.disable_fake_quant)
# calibration_code()
model_fp32_prepared(torch.randn(1, 1, 10, 10))

# prevents fake_quant from changing based on test code
model_fp32_prepared.apply(torch.ao.quantization.enable_fake_quant)
model_fp32_prepared.apply(torch.ao.quantization.disable_observer)

# test intB numbers
# test_code()
model_fp32_prepared(torch.randn(1, 1, 10, 10))

print(model_fp32_prepared)
Hi @HDCharles, thanks for the help. I managed to understand where to put the quant/dequant stubs and to also do QAT on my model (which I was not interested in at first). For example, after a few tens of QAT fine-tuning epochs I can reach a test accuracy that is very close to the original FP32 model with an INT8 configuration. In particular, I used the default torch.quantization.get_default_qconfig("fbgemm").

However, whenever I try to use a custom qconfig such as the FakeQuantize one you proposed, or also the following one:

intB_act_fq = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=int(2 ** B - 1), dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False)
intB_weight_fq = FakeQuantize.with_args(observer=PerChannelMinMaxObserver, quant_min=int(-(2 ** B) / 2), quant_max=int((2 ** B) / 2 - 1), dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False)
quantized_model.qconfig = QConfig(activation=intB_act_fq, weight=intB_weight_fq)

the QAT training does not converge anymore and instead stays at 10% evaluation accuracy independently of the value of bits B, even after a lot of epochs. Do you have any idea as to why this happens? Been trying to solve this for a while but no luck at all. Thanks

torch.quantization.get_default_qconfig("fbgemm") isn't going to work with QAT; that's a PTQ qconfig. See the quantization docs ("Quantization — PyTorch master documentation") for the default setup, or "(beta) Quantized Transfer Learning for Computer Vision Tutorial — PyTorch Tutorials 1.12.0+cu102 documentation" for a tutorial about QAT.

If you are using that qconfig, you are just doing normal gradient descent.
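One way to see the difference, sketched under the assumption that the "fbgemm" defaults behave as in recent PyTorch releases: the PTQ qconfig builds plain observers, which only record statistics, while the QAT qconfig builds fake-quantize modules, which actually round values in the forward pass.

```python
from torch.ao.quantization import get_default_qat_qconfig, get_default_qconfig
from torch.ao.quantization.fake_quantize import FakeQuantizeBase
from torch.ao.quantization.observer import ObserverBase

# A QConfig stores factories; calling them instantiates the module.
ptq_act = get_default_qconfig("fbgemm").activation()
qat_act = get_default_qat_qconfig("fbgemm").activation()

ptq_is_fake_quant = isinstance(ptq_act, FakeQuantizeBase)
qat_is_fake_quant = isinstance(qat_act, FakeQuantizeBase)

print(type(ptq_act).__name__, "fake-quantizes:", ptq_is_fake_quant)
print(type(qat_act).__name__, "fake-quantizes:", qat_is_fake_quant)
```

With only observers in place, the forward pass is numerically unchanged, so training behaves like ordinary fp32 fine-tuning.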

Oh ok, that’s why the fine-tuning was converging then.

Anyway, as I said I tried with the Qconfig FakeQuantize configuration I reported in my previous reply, as well as the one you mentioned in your last reply a few days ago, and the accuracy stays at random guess, even the loss stays at the same value at each fine-tuning epoch.

the problem is that this code wasn't intended to work with QAT, it's using parts of the QAT flow to do B bit PTQ (per the original question)

it should be noted that trying to use the above code to do QAT is moving in a bit of a circle. The QAT flow was first modified to do B bit PTQ, then almost all of those modifications would need to be undone to do B bit QAT.

It's easier to do B bit QAT starting from the normal QAT flow. The only difference is the QConfig, which needs to have quant_min and quant_max set based on the bitwidth B:

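##intB QAT qconfig:
from torch.ao.quantization import QConfig, get_default_qat_qconfig

qat_qconfig = get_default_qat_qconfig()

intB_qat_qconfig = QConfig(
    activation=qat_qconfig.activation.with_args(
        quant_min=0,
        quant_max=2 ** B - 1,
    ),
    # derive the weight fake-quant from the default QAT weight config
    weight=qat_qconfig.weight.with_args(
        quant_min=-(2 ** B) // 2,
        quant_max=(2 ** B) // 2 - 1,
    ),
)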

if none of that solves the issue, I would first get normal QAT working and then try to get the above qconfig working for B=8 (since that should be the same as normal QAT) before moving on to lower bitwidths

I’m not really sure I follow. As of now, I’m doing QAT fine-tuning as you explained with

quantized_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

and the training converges to reasonable enough accuracies, so I guess I managed to have QAT work? Then, if I try with the following configuration:

intB_act_fq = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, 
                                     quant_min=0, quant_max=int(2 ** B - 1), dtype=torch.quint8, 
                                     qscheme=torch.per_tensor_affine, reduce_range=False)

intB_weight_fq = FakeQuantize.with_args(observer=PerChannelMinMaxObserver, 
                                        quant_min=int(-(2 ** B) / 2), quant_max=int((2 ** B) / 2 - 1),
                                        dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False)

quantized_model.qconfig = QConfig(activation=intB_act_fq, 
                                  weight=intB_weight_fq)

where B=8 bits, which should be equivalent to the previous (normal QAT), the training also converges to similar values. However, as soon as I go B < 8, the training doesn’t converge, both training and eval accuracy stays at 10%, and the training loss barely decreases. I’m at a loss here, would it be maybe necessary to load a B=8 checkpoint to do QAT on a B=7 one, and so on down to B = 2? Thanks

The weight observer is wrong. If the weight is changing over time, minmaxobserver will work poorly. That’s why I’m saying it’s easier to just start from a normal qat setup and do the modification in my post rather than using the b bit ptq setup and trying to edit that to b bit qat.

As for checkpoints, that might help. You could also test whether the activation or the weight observer is what's causing the accuracy drop by relaxing one of them to its normal settings.

Hi @HDCharles, again I’m not sure I follow. Am I not already in a normal qat setup? I’m doing exactly as presented here Quantization — PyTorch 1.12 documentation (QAT section), but instead of using

model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

I’m using the custom FakeQuantize QConfig as in my previous comment with B=2, 3, 4, 5, 6, 7. I don’t think it’s an Observer problem because the custom config with B=8 works perfectly, and even if it’d perform poorly with B=7 it’s difficult to think I would have such a drop down to random guess accuracy. I also tried with other observers like HistogramObserver but nothing changes really (except for the fact that HistogramObserver is incredibly slow, at least 100 times slower). I tried searching for any entries about custom bitwidth QAT but couldn’t find anything helpful. I also tried going below 8 bits alternately just for weights and just for activations but it still doesn’t work: as soon as one of the QConfig settings goes <8, QAT doesn’t work for me.

@andrea.migliorati the weight observer you mentioned is a MinMaxObserver, which is not correct; it should be a MovingAverageMinMaxObserver. Use the setup above, i.e. ##intB QAT qconfig. QAT changes the weights over time, and if the weight observer is a MinMaxObserver it will retain a record of the max and min of all weights it has ever seen. If your weights change from initially being between -10 and 10 to finally being between 0 and 10, you will be operating with twice the quantization error you could have. With a moving average observer the min/max will gradually shift, fixing this problem.
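The difference can be reproduced with the two observer classes directly. This is only a sketch using the -10..10 to 0..10 example from the post; averaging_constant=0.5 is chosen here to make the drift visible in a few steps (the library default is much smaller).

```python
import torch
from torch.ao.quantization import MinMaxObserver, MovingAverageMinMaxObserver

minmax = MinMaxObserver()
moving = MovingAverageMinMaxObserver(averaging_constant=0.5)

early_weights = torch.tensor([-10.0, 10.0])  # weight range at the start of QAT
late_weights = torch.tensor([0.0, 10.0])     # weight range after training moved it

for observer in (minmax, moving):
    observer(early_weights)
    for _ in range(5):
        observer(late_weights)

minmax_min = minmax.min_val.item()
moving_min = moving.min_val.item()

# MinMaxObserver never forgets the old minimum; the moving average drifts toward 0.
print("MinMaxObserver min:", minmax_min)
print("MovingAverageMinMaxObserver min:", moving_min)
```

The stale -10 minimum means half of the quantization grid is spent on values the weights no longer take, which hurts more as B shrinks and there are fewer grid points to waste.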

It may not be the sole cause of your accuracy problem, but there may be more than one cause, or there may be a situation where the local optimum for B=8 isn’t a local optimum for B<8, so you run into the above problem where the weight ranges need to change significantly but the observer you have for weights makes that problematic.

You are correct that ideally going from B=8 to B=7 shouldn’t cause such a drastic drop in accuracy, (for the same reason that PTQ tends to give OK results which analogously goes from 32 to 8 bits).

Without more code it's hard to say what the issue could be. It's always possible there’s a bug in the way we calculate fake_quant/calculate_qparams for these bitwidths, but that should be easy to detect if you look at something like L2 error when the bit width changes.
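A rough sketch of such a check, assuming simple per-tensor affine quantization with min/max-derived scale and zero point. The helper fake_quant_l2_error is hypothetical and only approximates what an observer plus FakeQuantize would compute.

```python
import torch


def fake_quant_l2_error(x, bits):
    """L2 distance between x and its B-bit per-tensor affine fake-quantized copy."""
    quant_min, quant_max = 0, 2 ** bits - 1
    # Extend the range so that zero is representable, as PyTorch observers do.
    lo = min(x.min().item(), 0.0)
    hi = max(x.max().item(), 0.0)
    scale = max((hi - lo) / (quant_max - quant_min), 1e-8)
    zero_point = int(round(quant_min - lo / scale))
    zero_point = max(quant_min, min(quant_max, zero_point))
    x_fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, quant_min, quant_max)
    return torch.norm(x - x_fq).item()


torch.manual_seed(0)
weights = torch.randn(1000)  # stand-in for a real weight tensor
errors = {bits: fake_quant_l2_error(weights, bits) for bits in (8, 6, 4, 2)}
for bits, err in errors.items():
    print(f"B={bits}: L2 error = {err:.4f}")
```

The error should grow smoothly as B decreases. A sudden jump between adjacent bitwidths, or an error that does not change with B at all, would point at the qparams calculation rather than at training.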

@HDCharles

Hi,

I’m trying to do 4-bit QAT on Resnet18.

After training the network using the ##intB QAT qconfig above, I’m a bit confused on how to quantize the network. Will torch.quantization.convert(model) be sufficient, or do we need a custom mapping dict?

When I run validation after torch.quantization.convert(model), I get the following error: RuntimeError: Could not run 'quantized::conv2d_relu.new' with arguments from the 'UNKNOWN_TENSOR_TYPE_ID' backend. 'quantized::conv2d_relu.new' is only available for these backends: [QuantizedCPU].

Thanks!

You can’t, there are no int B quantized kernels to use unfortunately

To be clear, the original solution was intended to simulate B bit PTQ using parts of the QAT flow, then there was a followup about simulating QAT, but all of these solutions end at the point where you have a prepared model, ready to convert. At that point you can disable observers and run validation/test, but you won’t be able to convert it to actual quantized kernels like you can for the normal quantization flow, because no such kernels exist as of now. It's possible those will be created in the future, but that's more a question for the backend teams, i.e. FBGEMM, QNNPACK or XNNPACK.

@HDCharles Sure, you’re right, I’ll open another dedicated entry in the forum to avoid confusion. Can you post here nonetheless a link of how the normal QAT is and what’s the flow to follow? Thank you

all the quantization flows are described in the docs, here is the specific QAT link:

https://pytorch.org/docs/stable/quantization.html#quantization-aware-training-for-static-quantization

Thank you a lot for the help, it helped me in my progress.

For anyone reading this thread later, I recommend "Practical Quantization in PyTorch | PyTorch" for tips in practice.

intB_qat_qconfig = torch.quantization.QConfig(
    activation=torch.quantization.MovingAverageMinMaxObserver.with_args(
        quant_min=0,
        quant_max=2 ** bitwidth - 1,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=True,
    ),
    weight=torch.quantization.MinMaxObserver.with_args(
        quant_min=-(2 ** bitwidth) // 2,
        quant_max=(2 ** bitwidth) // 2 - 1,
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric,
        reduce_range=False,
    ),
)
model_ft.qconfig = intB_qat_qconfig

This is what worked for me for <8-bit quantization. @andrea.migliorati, if you still have the same problem, try it.

Hi Jerry, is it possible to simulate the INT4 Post-Training Quantization by inserting FakeQuantize?