Fusing a QAT model post-training

Hello, I have a QAT model that I have already trained. After training, though, I would like to fuse the conv, bn, and relu layers of the model. However, after doing this, I seem to lose the quantization scale for the resulting ConvReLU2d layer. I’ve created a minimal example.

import torch as th

class Model(th.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = th.quantization.QuantStub()
        self.conv = th.nn.Conv2d(1, 3, 3, bias=False)
        self.bn = th.nn.BatchNorm2d(3)
        self.relu = th.nn.ReLU()
        self.dequant = th.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

# Initialize model
m = Model()

# Set QAT qconfig (fake-quantize modules for weights and activations)
m.qconfig = th.ao.quantization.get_default_qat_qconfig("fbgemm")

# Train
x = th.rand(64, 1, 32, 30)
pm = th.quantization.prepare_qat(m.train())
y = pm(x)

At this point, I can retrieve the quantized model as normal:

qm = th.ao.quantization.convert(pm.eval())

And I can access the quantized weights of the convolutional layer as normal:


However, what I really want to do is fuse the conv, bn and relu layers together. I’ve tried the following:

fused_pm = th.ao.quantization.fuse_modules(pm, [['conv','bn','relu']])

This results in the following error:

AssertionError: did not find fuser method for: (<class 'torch.ao.nn.qat.modules.conv.Conv2d'>, <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, <class 'torch.nn.modules.activation.ReLU'>) 

I receive a similar error if I instead try to fuse the quantized model (qm).

Is there a way to fuse a QAT model after training and access the updated quantized version? Thank you.

The order for QAT with fused modules should be:

  1. fuse_modules
  2. prepare_qat
  3. convert
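The three steps above can be sketched end to end as follows. This is a minimal sketch, assuming a recent PyTorch 2.x eager-mode workflow on an x86 build; it swaps in `fuse_modules_qat` (the training-mode fusion helper) and `get_default_qat_qconfig`, neither of which appears in the posts, and the short SGD loop is only a stand-in for real training. After `convert`, the fused layer is a quantized `ConvReLU2d` that carries its own `scale` and `zero_point`.

```python
import torch
import torch.ao.quantization as tq

# Assumption: use the fbgemm engine when this build supports it, so weight
# packing during convert matches the "fbgemm" qconfig below
if "fbgemm" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "fbgemm"


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.conv = torch.nn.Conv2d(1, 3, 3, bias=False)
        self.bn = torch.nn.BatchNorm2d(3)
        self.relu = torch.nn.ReLU()
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)


m = Model().train()
m.qconfig = tq.get_default_qat_qconfig("fbgemm")

# 1. Fuse in training mode: BN is kept inside a ConvBnReLU2d so its
#    statistics can still update during QAT
fused = tq.fuse_modules_qat(m, [["conv", "bn", "relu"]])

# 2. prepare_qat: swaps in QAT modules with fake quantization
pm = tq.prepare_qat(fused)

# Stand-in for a real training loop
opt = torch.optim.SGD(pm.parameters(), lr=0.01)
for _ in range(3):
    loss = pm(torch.rand(64, 1, 32, 30)).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# 3. convert: BN is folded into the conv weights; conv + relu become
#    a single quantized ConvReLU2d
qm = tq.convert(pm.eval())

print(type(qm.conv).__name__)  # ConvReLU2d
print(qm.conv.scale, qm.conv.zero_point)  # output quantization params
print(qm.conv.weight())  # int8 per-channel quantized weight
```

Because fusion happens before `prepare_qat`, the observers attach to the fused module, so the quantization scale of the conv + relu output is learned during training rather than lost afterwards.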

Can you try this?

In my initial example, I realized I was using prepare instead of prepare_qat; I’ve since updated my question. At this point, I have my prepared and trained model, and I tried fusing it like this:

fused_pm = th.quantization.fuse_modules(pm.eval(), [["conv", "bn", "relu"]])

However, I get the following error:

AssertionError: did not find fuser method for: (<class 'torch.ao.nn.qat.modules.conv.Conv2d'>, <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, <class 'torch.nn.modules.activation.ReLU'>) 

Is fusion not supported for QAT versions of conv and bn? Or is there a different function I should be using to fuse?
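One way to see why the lookup fails: `prepare_qat` has already replaced `nn.Conv2d` with its QAT counterpart, while the default fuser table that `fuse_modules` consults is keyed on the plain float module types, so the tuple in the error has no entry. The sketch below reproduces this, assuming PyTorch 2.x and using a bare `nn.Sequential` for brevity (its children are named `"0"`, `"1"`, `"2"`). It catches both `AssertionError` and `NotImplementedError` because the exact exception type may differ between releases.

```python
import torch
import torch.ao.quantization as tq

# Bare conv -> bn -> relu block; Sequential children are named "0", "1", "2"
seq = torch.nn.Sequential(
    torch.nn.Conv2d(1, 3, 3, bias=False),
    torch.nn.BatchNorm2d(3),
    torch.nn.ReLU(),
)
seq.qconfig = tq.get_default_qat_qconfig("fbgemm")

# prepare_qat requires training mode and swaps nn.Conv2d for a QAT Conv2d
pm = tq.prepare_qat(seq.train())
print(type(pm[0]))  # a QAT Conv2d class, not torch.nn.Conv2d

# Fusing after prepare_qat: (QAT Conv2d, BatchNorm2d, ReLU) has no fuser method
fusion_failed = False
try:
    tq.fuse_modules(pm.eval(), [["0", "1", "2"]])
except (AssertionError, NotImplementedError) as err:
    fusion_failed = True
    print("fusion failed:", err)
```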

You should apply fusion before prepare_qat.

Or you could try out our new API, described in the "(prototype) PyTorch 2 Export Quantization-Aware Training (QAT)" tutorial (PyTorch Tutorials 2.3.0+cu121 documentation), where all these details are taken care of for you.
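Staying with the eager-mode API, the two fusion helpers produce different modules, which is why QAT calls for fusing in training mode. A minimal sketch, assuming PyTorch 2.x and using a bare `nn.Sequential` for brevity: `fuse_modules` on an eval-mode model folds BN into the conv weights and returns a `ConvReLU2d`, while `fuse_modules_qat` on a training-mode model keeps BN as a submodule of a `ConvBnReLU2d`, which `prepare_qat` can then turn into a QAT module that keeps training BN.

```python
import torch
import torch.ao.quantization as tq


def make_block():
    # Fresh conv -> bn -> relu block; children are named "0", "1", "2"
    return torch.nn.Sequential(
        torch.nn.Conv2d(1, 3, 3, bias=False),
        torch.nn.BatchNorm2d(3),
        torch.nn.ReLU(),
    )


# Eval-mode fusion: BN is folded into the conv weights
eval_fused = tq.fuse_modules(make_block().eval(), [["0", "1", "2"]])

# QAT fusion: BN is kept so its statistics can update during training
qat_fused = tq.fuse_modules_qat(make_block().train(), [["0", "1", "2"]])

print(type(eval_fused[0]).__name__)  # ConvReLU2d
print(type(qat_fused[0]).__name__)   # ConvBnReLU2d
```

In both cases the fused-away positions (`"1"` and `"2"`) are replaced with `nn.Identity`, so the model's `forward` does not need to change.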