Using quantizable model for normal training

I want to integrate quantization aware training (QAT) as an option into my training code. Is there anything that speaks against preparing models for QAT and training them normally, i.e. without QAT?

By preparation I mean adding QuantStub, FloatFunctional, and using already-fused modules such as ConvReLU2d. As far as I can tell from the source, these modules only serve as placeholders for the fake-quantization modules that are swapped in when the model is actually converted for QAT, and they do not affect training otherwise.

The question comes up because I saw that torchvision includes two versions of its models, a normal and a quantizable one. Unless it is there for backward compatibility, this seems unnecessary for my own code. Is there something I am overlooking?

Hi, by “quantizable” I suppose you mean the ones in torchvision specifically. These are needed for eager mode quantization, since the user needs to manually insert QuantStubs and DeQuantStubs like you mentioned. For QAT, these will be replaced by FakeQuantizes, which actually do change the numerics of training. That’s the point of QAT in the first place, which is to improve the accuracy of quantization by making the training process “aware” that the model will ultimately be quantized.
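For intuition, here is a minimal sketch (assuming the default FakeQuantize configuration, not something from your code) of why this changes the numerics: the module snaps values to a simulated 8-bit grid, so the forward pass is no longer bit-identical to the float one.

import torch
from torch.ao.quantization import FakeQuantize

fq = FakeQuantize()        # default: per-tensor affine fake quant with an 8-bit observer
x = torch.randn(4)
y = fq(x)                  # observes min/max, then rounds to the quantization grid
print(x)                   # original float values
print(y)                   # values snapped to the simulated int8 grid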

So my recommendation is the following. Either make “quantizable” versions of your model similar to torchvision, which uses eager mode quantization, or switch to FX graph mode quantization, where you don’t have to change a thing about your model and it’ll still be quantized automatically (with FakeQuantizes inserted for the QAT case). You can learn more about FX graph mode quantization here: (prototype) FX Graph Mode Quantization User Guide — PyTorch Tutorials 2.0.1+cu117 documentation. Please feel free to let me know if there’s anything else I can clarify.
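For reference, a rough sketch of the FX graph mode QAT flow (the module paths and the "x86" backend string assume a recent PyTorch release; the toy model and input shape are just placeholders):

import copy
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

# Placeholder float model; in practice this is your unmodified model.
float_model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())

qconfig_mapping = get_default_qat_qconfig_mapping("x86")
example_inputs = (torch.randn(1, 3, 32, 32),)

# FakeQuantizes are inserted automatically based on the traced graph.
prepared = prepare_qat_fx(copy.deepcopy(float_model).train(), qconfig_mapping, example_inputs)
# ... run the usual training loop on `prepared` ...
quantized = convert_fx(prepared.eval())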

Best,
-Andrew

Hi Andrew,
as my model structures are a bit complicated, I need to use eager mode quantization. What I meant by “making models quantizable” is including all the modules necessary for QAT or PTQ right away, whether they are used or not.

For example, instead of defining my model as:

import torch.nn as nn

class Model(nn.Module):

    def __init__(self, nc_in: int=3):
        super().__init__()
        self.conv = nn.Conv2d(nc_in, 64, 3, 1, 1, bias=False)
        self.bnorm = nn.BatchNorm2d(64, affine=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return x + self.relu(self.bnorm(self.conv(x)))

and then fusing it, adding QuantStub/DeQuantStub, etc. later, or even maintaining an additional model only for QAT, I could simply define my model as

import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub
from torch.ao.nn.intrinsic import ConvBnReLU2d
from torch.ao.nn.quantized import FloatFunctional

class QuantizableModel(nn.Module):

    def __init__(self, nc_in: int=3):
        super().__init__()
        self.quant = QuantStub()
        self.conv_bnorm_relu = ConvBnReLU2d(
                nn.Conv2d(nc_in, 64, 3, 1, 1, bias=False),
                nn.BatchNorm2d(64, affine=True),
                nn.ReLU(inplace=True),
            )
        self.add = FloatFunctional()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.add.add(x, self.conv_bnorm_relu(x))
        x = self.dequant(x)
        return x

and use it for training whether using QAT or not.
Thus, I would not have to deal with fusing modules or adding QuantStubs, since everything is already defined specifically for this type of model. All that would be required is adding a qconfig and calling prepare_qat:

if use_qat:
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
    torch.ao.quantization.prepare_qat(model, inplace=True)

And when not using QAT, nothing would change compared to the previous model.
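For completeness, this is how I picture the full eager-mode flow (just a sketch, assuming the QuantizableModel above and recent torch.ao.quantization paths; use_qat is a flag from my training config):

import torch
import torch.ao.quantization

use_qat = True                                # flag from the training config
model = QuantizableModel().train()            # prepare_qat expects a model in train mode

if use_qat:
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
    torch.ao.quantization.prepare_qat(model, inplace=True)

# ... identical training loop whether use_qat is True or False ...

if use_qat:
    model.eval()
    quantized_model = torch.ao.quantization.convert(model)   # int8 model for inference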

Well, that last point is my actual question: I was uncertain whether there are any pitfalls, any other changes made in torch.ao.quantization.fuse_modules_qat, or anything in the fused placeholders that might be an obstacle.

I think Conv-BN fusion might introduce some slight numerical mismatches, but I don’t see it affecting the model accuracy in general.

This is true if the modules are actually fused.
But as far as I understand it, the ConvBnReLU2d module and all the other modules from torch.ao.nn.intrinsic only act as placeholders for the actual fused modules. If you look at its source code and the source of its base class _FusedModule, you can see that it is just a Sequential module calling its submodules one after another.

https://pytorch.org/docs/stable/_modules/torch/ao/nn/intrinsic/modules/fused.html#ConvBnReLU2d
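A quick sanity check along these lines (just a sketch I would use, not from the docs): the intrinsic container produces exactly the same output as calling conv, bn, and relu in sequence.

import torch
import torch.nn as nn
from torch.ao.nn.intrinsic import ConvBnReLU2d

conv = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
bn = nn.BatchNorm2d(64)
relu = nn.ReLU()
placeholder = ConvBnReLU2d(conv, bn, relu)    # just a Sequential wrapping the same modules

placeholder.eval()                            # eval keeps BN running stats fixed between the two calls
x = torch.randn(1, 3, 32, 32)
assert torch.equal(placeholder(x), relu(bn(conv(x))))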

The actual fusion of modules for QAT seems to happen in torch.ao.quantization.prepare_qat, which replaces the placeholders with the corresponding fused QAT modules.
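For example (again a sketch, assuming the QuantizableModel above and recent module paths), the placeholder’s type changes after prepare_qat:

import torch
import torch.ao.quantization
from torch.ao.nn.intrinsic import ConvBnReLU2d
from torch.ao.nn.intrinsic.qat import ConvBnReLU2d as QATConvBnReLU2d

model = QuantizableModel().train()
assert isinstance(model.conv_bnorm_relu, ConvBnReLU2d)        # plain Sequential placeholder

model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
torch.ao.quantization.prepare_qat(model, inplace=True)
assert isinstance(model.conv_bnorm_relu, QATConvBnReLU2d)     # fused QAT module with fake quant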

Oh yes, that’s correct, I think. fuse_modules won’t change the numerics.