"conv", "bn", "relu" becomes fused as ConvReLU2d instead of ConvBnReLU2d

Hello, when I fuse conv, BN, and ReLU layers together using torch.ao.quantization.fuse_modules(m, modules_to_fuse), the fused module does not include the BN layer, unlike in the quantizable ResNet18.

In the sample code, conv, BN, and ReLU result in ConvReLU2d:
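Here is a minimal reproduction of what I mean (the toy module below just illustrates my setup, it is not the exact sample code):

import torch
import torch.nn as nn

m = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU(),
)
m.eval()  # eager-mode PTQ fusion requires eval mode
fused = torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]])
print(fused)
# (0): ConvReLU2d(...)  <- only conv + relu, the BN is gone
# (1): Identity()
# (2): Identity()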

In the quantizable ResNet18, by contrast, there's ConvBnReLU2d:
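For example (assuming a recent torchvision where fuse_model() accepts is_qat):

from torchvision.models.quantization import resnet18

qm = resnet18(quantize=False)
qm.train()                   # QAT-style fusion expects train mode
qm.fuse_model(is_qat=True)   # keeps the BN inside the fused module
print(qm.conv1)
# ConvBnReLU2d(
#   (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#   (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (2): ReLU(...)
# )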

Thank you in advance for your help!

I haven’t checked the internal code, but maybe the parameters and buffers of the batchnorm layer were all folded into the parameters of the conv? Have you checked some parameter values?
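For example, a quick numerical check (a minimal sketch with a toy eval-mode module, assuming a setup like yours):

import torch
import torch.nn as nn

m = nn.Sequential(nn.Conv2d(3, 8, 3, bias=False), nn.BatchNorm2d(8), nn.ReLU()).eval()
fused = torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]])  # returns a fused copy

x = torch.randn(1, 3, 16, 16)
print(torch.allclose(m(x), fused(x), atol=1e-6))  # True: the BN was folded into the conv
print(fused[0][0].bias is not None)               # True: the fused conv gained a bias from the BN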

It seems the is_qat argument of torchvision's _fuse_modules() helper selects between fusion for training (QAT) and fusion for evaluation. As far as I know, the batchnorm layer is not needed for static quantization but is needed for QAT, so it is folded into the conv layer when QAT is disabled. Here's a quick summary:

First, a little code snippet from torchvision that applies the proper fusion:

from typing import Any, List, Optional, Union

import torch
from torch import nn


def _fuse_modules(
    model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any
):
    # Default to the model's current mode: train() implies QAT, eval() implies PTQ.
    if is_qat is None:
        is_qat = model.training
    method = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
    return method(model, modules_to_fuse, **kwargs)

The line method = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules shows that the QAT fusion takes a different path from the default one.
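For reference, the m and modules_to_fuse used in the transcript below can be built roughly like this (the exact constructor arguments are my guess; note that fuse_modules returns a fused copy by default, which is why the same m can be fused twice):

>>> import torch
>>> from torchvision.ops import Conv2dNormActivation
>>> m = Conv2dNormActivation(3, 32, kernel_size=3, stride=2, inplace=None)
>>> modules_to_fuse = ["0", "1", "2"]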

>>> m  # module itself
Conv2dNormActivation(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)

>>> m.eval()
Conv2dNormActivation(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)

>>> _fuse_modules(m, modules_to_fuse, is_qat=False)
Conv2dNormActivation(
  (0): ConvReLU2d(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
  )
  (1): Identity()
  (2): Identity()
)

>>> _fuse_modules(m, modules_to_fuse, is_qat=True)
Conv2dNormActivation(
  (0): ConvBnReLU2d(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (1): Identity()
  (2): Identity()
)

Yes. For PTQ (post-training quantization), the weights of "bn" are folded into "conv" at fusion time, so we get a torch.ao.nn.intrinsic.ConvReLU2d module (which convert later turns into the quantized ConvReLU2d). For QAT (quantization-aware training), we first fuse "conv" - "bn" - "relu" into a torch.ao.nn.intrinsic.ConvBnReLU2d module; it will then be swapped for a torch.ao.nn.intrinsic.qat.ConvBnReLU2d module, which simulates the numerics of a quantized ConvReLU2d module at training time.
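A minimal eager-mode sketch of that QAT flow (the toy module and the qconfig choice are my own assumptions):

import torch
import torch.nn as nn

m = nn.Sequential(nn.Conv2d(3, 8, 3, bias=False), nn.BatchNorm2d(8), nn.ReLU()).train()
m = torch.ao.quantization.fuse_modules_qat(m, [["0", "1", "2"]])
print(type(m[0]))  # intrinsic ConvBnReLU2d: the BN is kept for training

m.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
m = torch.ao.quantization.prepare_qat(m)
print(type(m[0]))  # intrinsic.qat ConvBnReLU2d: simulates quantized numerics

m(torch.randn(2, 3, 16, 16))  # stand-in for the actual QAT training loop
m.eval()
m = torch.ao.quantization.convert(m)
print(type(m[0]))  # quantized ConvReLU2d: the BN is folded away after training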
