QAT - how to handle nn.Parameter

Hello, I’m trying to train a model with QAT, but I’m stuck on nn.Parameter and can’t find a solution. The model I want to quantize contains parameters that should be “learned”. I prepared a code snippet that shows my problem:

import torch
import torch.nn as nn


class test_module(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=10,
            kernel_size=1,
        )
        self.beta = nn.Parameter(torch.zeros((1, 10, 1, 1)), requires_grad=True)
        self.skip_mul = nn.quantized.FloatFunctional()

    def forward(self, x):
        x = self.conv1(x)
        x = self.skip_mul.mul(x, self.beta)
        return x


class qat_model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.quant = torch.quantization.QuantStub()
        self.conv = test_module()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)

        return x


model = qat_model()
model.eval()
model.to("cpu")
dummy_input = torch.rand(1, 3, 256, 256)

model.qconfig = torch.quantization.get_default_qat_qconfig("qnnpack")
torch.quantization.prepare(model, inplace=True)
torch.quantization.convert(model, inplace=True)

out = model(dummy_input)
print("quant out", out)

And I get this error: RuntimeError: Mul operands should have same data type.
x has dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0, while self.beta remains torch.float32.

I would be grateful if someone could show me what is the correct solution to such a problem.

If you are doing QAT you don’t convert first: you only prepare, then you do the actual QAT training/backprop, and once you’re done, you convert.
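
Roughly, that flow looks like this (a minimal sketch based on the qat_model class from your snippet; the training loop is just a placeholder):

model = qat_model()
model.train()  # prepare_qat expects the model in train mode
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
model_prepared = torch.ao.quantization.prepare_qat(model)

# ... run the usual training loop on model_prepared here (fake-quant observers update during training) ...

model_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_prepared)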

Ok, but that’s not what the question was about. In this topic I’m asking how to handle the nn.Parameter, which does not automatically change after conversion; it still operates on float32 instead of quint8.

Sorry, maybe the topic title or the snippet was misleading.

There are multiple issues: you are using prepare and convert but with a QAT qconfig. Compare to the QAT snippet in Quantization — PyTorch 2.2 documentation.

You seem to have the FloatFunctional set up correctly, so it’s just a case of getting the rest of the flow set up correctly.

I still don’t know whether the nn.Parameter should be converted to quint8 automatically or whether I should do something about it manually, which is the crux of my problem.

After the correction, the error still occurs. The code is now almost the same as in the example you sent, except for the fusing, because as the docstring says:

    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    bn, relu

and I have no combination of those layers in my model.
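
Just for reference, if my model did contain one of those sequences, say a conv followed directly by a relu, the fusion call would look roughly like this (the "conv" and "relu" names are hypothetical and don’t exist in my model):

# hypothetical: only valid if the model actually has submodules named "conv" and "relu"
model_fused = torch.ao.quantization.fuse_modules(model_fp32.eval(), [["conv", "relu"]])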

After the correction, the snippet looks like this:

import torch
import torch.nn as nn


class test_module(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=10,
            kernel_size=1,
        )
        self.beta = nn.Parameter(torch.zeros((1, 10, 1, 1)), requires_grad=True)
        self.skip_mul = nn.quantized.FloatFunctional()

    def forward(self, x):
        x = self.conv1(x)
        x = self.skip_mul.mul(x, self.beta)
        return x


class M(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.quant = torch.quantization.QuantStub()
        self.conv = test_module()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)

        return x


model_fp32 = M()
model_fp32.eval()

model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32.train())
# training_loop(model_fp32_prepared)
model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

input_fp32 = torch.rand(1, 3, 256, 256)
out = model_int8(input_fp32)
print("quant out:\n", out)

Normally you wouldn’t. You are doing a conv op and then rescaling each of the 10 output channels by some constant, which is the same as just rescaling the weights (and bias) of your conv module, and that would be faster.
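
For illustration, folding the per-channel scale into the conv of the float test_module would look roughly like this (just a sketch, where module stands for a test_module instance; this is not part of any quantization API):

with torch.no_grad():
    scale = module.beta.reshape(-1)                    # shape (10,), one factor per output channel
    module.conv1.weight.mul_(scale.view(-1, 1, 1, 1))  # rescale each output channel's weights
    if module.conv1.bias is not None:
        module.conv1.bias.mul_(scale)                  # and the matching bias entries
# after folding, beta and the skip_mul.mul call can be dropped from forward()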

The situation isn’t really supported as a result.

It’s also just going to perform poorly: you’d be better off doing one quantized conv rather than a quantized conv plus a quantized mul, since you’d be introducing quantization error twice. If you really want to do it like this, you can do QAT as normal, but then after convert you have to manually quantize the parameter to avoid the error you’re seeing.
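
Something along these lines, as a rough sketch against the converted model_int8 from your snippet (the scale/zero_point choice here is arbitrary, and you’d want to verify that the quantized mul kernel accepts the (1, 10, 1, 1) broadcast on your build):

with torch.no_grad():
    beta = model_int8.conv.beta.detach()
    scale = float(beta.abs().max()) / 127.0 + 1e-12  # crude, roughly symmetric range
    q_beta = torch.quantize_per_tensor(beta, scale=scale, zero_point=128, dtype=torch.quint8)
del model_int8.conv.beta       # drop the float32 Parameter
model_int8.conv.beta = q_beta  # plain attribute; self.beta in forward() now resolves to the quantized tensor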