RuntimeError: Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPUTensorId' backend. 'aten::native_batch_norm' is only available for these backends: [CPUTensorId, MkldnnCPUTensorId, VariableTensorId]

I have a quantized model which is basically a resnet18. the quantization seems to go just fine until, when I try to load the quantized model from disk using sth like this :

def load_quantized(quantized_checkpoint_file_path):
    model = fvmodels.resnet18(pretrained=False, use_se=True)
    model.eval()
    model.fuse_model()
    # print(f'model: {model}')
    # Specify quantization configuration
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    # print(model.qconfig)
    torch.quantization.prepare(model, inplace=True)
    # Convert to quantized model
    torch.quantization.convert(model, inplace=True)
    checkpoint = torch.load(quantized_checkpoint_file_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint, strict=False)
    # model = torch.jit.load(quantized_checkpoint_file_path, map_location=torch.device('cpu'))
    fvmodels.print_size_of_model(model)
    return model

and while trying to use that :

model = load_quantized('path to model')
model.eval()
with torch.no_grad():
    for img, lbl in dtloader:
        features = model(img.unsqueeze(0))

I face the following error :

RuntimeError: Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPUTensorId' backend. 'aten::native_batch_norm' is only available for these backends: [CPUTensorId, MkldnnCPUTensorId, VariableTensorId].

This seems to be casued by the fact that the batchnorm layer is not fused! and the issue is I dont know how to fuse it. to be more specific here is the resnet model I have at hand :

class ResNet(nn.Module):

    def __init__(self, block, layers, use_se=True):
        self.inplanes = 64
        self.use_se = use_se
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # self.prelu = nn.PReLU()
        self.prelu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn2 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(512 * 7 * 7, 512)
        self.bn3 = nn.BatchNorm1d(512)

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        
        x = self.quant(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2(x)
        x = self.dropout(x)
        # x = x.view(x.size(0), -1)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        x = self.bn3(x)

        x = self.dequant(x)
        return x

    def fuse_model(self):
        r"""Fuse conv/bn/relu modules in resnet models
        Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """

        fuse_modules(self, [['conv1', 'bn1', 'prelu'],
                            ['bn2'],
                            ['bn3']], inplace=True)
        for m in self.modules():
            # print(m)
            if type(m) == Bottleneck or type(m) == BasicBlock or type(m) == IRBlock:
                m.fuse_model()

as you can see in the forward pass we have :

...
x = self.bn2(x)
x = self.dropout(x)

which is followed by a dropout and unlike previous ones, doesnt come with neither conv or relu!
the same thing goes to bn3 a couple of lines later:

...
x = self.fc(x)
x = self.bn3(x)
x = self.dequant(x)
....

So I’m not sure how I’m supposed to get around this. obviously the way I’m fusing is wrong:

def fuse_model(self):
        fuse_modules(self, [['conv1', 'bn1', 'prelu'],
                            ['bn2'],
                            ['bn3']], inplace=True)
        for m in self.modules():
            # print(m)
            if type(m) == Bottleneck or type(m) == BasicBlock or type(m) == IRBlock:
                m.fuse_model()

For the sake of completeness here are the whole models :


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # out += residual
        # out = self.relu(out)
        out = self.add_relu.add_relu(out, residual)

        return out

    def fuse_model(self):
        torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
                                               ['conv2', 'bn2']], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride
        
        self.skip_add_relu = nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        # out += residual
        # out = self.relu(out)
        out = self.skip_add_relu.add_relu(out, residual)
        return out

    def fuse_model(self):
        fuse_modules(self, [['conv1', 'bn1', 'relu1'],
                            ['conv2', 'bn2', 'relu2'],
                            ['conv3', 'bn3']], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.mult_xy = nn.quantized.FloatFunctional()

        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            # nn.PReLU(),
            nn.ReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        # out = x*y 
        out = self.mult_xy.mul(x, y)
        return out

class IRBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        # self.prelu = nn.PReLU()
        self.prelu = nn.ReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

        self.add_residual_relu = nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = x

        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # out += residual
        # out = self.prelu(out)

        # we may need to change prelu into relu and this, instead of add, use add_relu here
        out = self.add_residual_relu.add_relu(out, residual)
        # out = self.prelu(out)
        return out

    def fuse_model(self):
        fuse_modules(self, [['conv1', 'bn1', 'prelu'],
                            ['conv2', 'bn2']], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class ResNet(nn.Module):

    def __init__(self, block, layers, use_se=True):
        self.inplanes = 64
        self.use_se = use_se
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # self.prelu = nn.PReLU()
        self.prelu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn2 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(512 * 7 * 7, 512)
        self.bn3 = nn.BatchNorm1d(512)

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        
        x = self.quant(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2(x)
        x = self.dropout(x)
        # x = x.view(x.size(0), -1)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        x = self.bn3(x)

        x = self.dequant(x)
        return x

    def fuse_model(self):
        r"""Fuse conv/bn/relu modules in resnet models
        Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """

        fuse_modules(self, [['conv1', 'bn1', 'prelu'],
                            ['bn2'],
                            ['bn3']], inplace=True)
        for m in self.modules():
            # print(m)
            if type(m) == Bottleneck or type(m) == BasicBlock or type(m) == IRBlock:
                m.fuse_model()

def resnet18(pretrained, use_se, **kwargs):
    model = ResNet(IRBlock, [2, 2, 2, 2], use_se=use_se, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

side note:
Also the actual model (resnet18 can be found from this link in case someone might need it)
Config:
Im using Pytorch 1.5.0+cpu on windows 10 x64 v1803

Any help is greatly appreciated

fuse_modules(self, [['conv1', 'bn1', 'prelu'],
                            ['bn2'],
                            ['bn3']], inplace=True)

change into:
fuse_modules(self, [‘conv1’, ‘bn1’, ‘prelu’], inplace=True)

Thanks but, that was the initial attempt which results in the mentioned error as well.

Traceback (most recent call last):
  File "d:\Codes\org\python\Quantization\quantizer.py", line 265, in <module>
    features = model(img.unsqueeze(0))
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 418, in forward
    x = self.bn3(x)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\batchnorm.py", line 106, in forward
    exponential_average_factor, self.eps)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\functional.py", line 1923, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPUTensorId' backend. 'aten::native_batch_norm' is only available for these backends: [CPUTensorId, MkldnnCPUTensorId, VariableTensorId].

Can you try printing the quantized model after prepare and convert? We do support quantized batch_norm so nn.BatchNorm2d module should get replaced with quantized one.

Hi, here it is :
Ran using the latest nighly 1.7.0.dev20200714+cpu and torchvision-0.8.0.dev20200714+cpu

Size (MB): 87.218199
QConfig(activation=functools.partial(<class 'torch.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
Model after being fused-prepared: ResNet(
  (conv1): Conv2d(
    3, 64, kernel_size=(3, 3), stride=(1, 1)
    (activation_post_process): HistogramObserver()
  )
  (bn1): Identity()
  (prelu): PReLU(num_parameters=1)
  (prelu_q): PReLU_Quantized(
    (quantized_op): FloatFunctional(
      (activation_post_process): HistogramObserver()
    )
    (quant): QuantStub(
      (activation_post_process): HistogramObserver()
    )
    (dequant): DeQuantStub()
  )
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): IRBlock(
      (bn0): BatchNorm2d(
        64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): HistogramObserver()
      )
      (conv1): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn1): Identity()
      (prelu): PReLU(num_parameters=1)
      (prelu_q): PReLU_Quantized(
        (quantized_op): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (quant): QuantStub(
          (activation_post_process): HistogramObserver()
        )
        (dequant): DeQuantStub()
      )
      (conv2): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn2): Identity()
      (se): SEBlock(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (mult_xy): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (fc): Sequential(
          (0): Linear(
            in_features=64, out_features=4, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU(num_parameters=1)
          (2): Linear(
            in_features=4, out_features=64, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
        (fc1): Linear(
          in_features=64, out_features=4, bias=True
          (activation_post_process): HistogramObserver()
        )
        (prelu): PReLU(num_parameters=1)
        (fc2): Linear(
          in_features=4, out_features=64, bias=True
          (activation_post_process): HistogramObserver()
        )
        (sigmoid): Sigmoid()
        (prelu_q): PReLU_Quantized(
          (quantized_op): FloatFunctional(
            (activation_post_process): HistogramObserver()
          )
          (quant): QuantStub(
            (activation_post_process): HistogramObserver()
          )
          (dequant): DeQuantStub()
        )
        (fc_q): Sequential(
          (0): Linear(
            in_features=64, out_features=4, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU_Quantized(
            (quantized_op): FloatFunctional(
              (activation_post_process): HistogramObserver()
            )
            (quant): QuantStub(
              (activation_post_process): HistogramObserver()
            )
            (dequant): DeQuantStub()
          )
          (2): Linear(
            in_features=4, out_features=64, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
      )
      (add_residual_relu): FloatFunctional(
        (activation_post_process): HistogramObserver()
      )
    )
    (1): IRBlock(
      (bn0): BatchNorm2d(
        64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): HistogramObserver()
      )
      (conv1): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn1): Identity()
      (prelu): PReLU(num_parameters=1)
      (prelu_q): PReLU_Quantized(
        (quantized_op): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (quant): QuantStub(
          (activation_post_process): HistogramObserver()
        )
        (dequant): DeQuantStub()
      )
      (conv2): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn2): Identity()
      (se): SEBlock(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (mult_xy): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (fc): Sequential(
          (0): Linear(
            in_features=64, out_features=4, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU(num_parameters=1)
          (2): Linear(
            in_features=4, out_features=64, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
        (fc1): Linear(
          in_features=64, out_features=4, bias=True
          (activation_post_process): HistogramObserver()
        )
        (prelu): PReLU(num_parameters=1)
        (fc2): Linear(
          in_features=4, out_features=64, bias=True
          (activation_post_process): HistogramObserver()
        )
        (sigmoid): Sigmoid()
        (prelu_q): PReLU_Quantized(
          (quantized_op): FloatFunctional(
            (activation_post_process): HistogramObserver()
          )
          (quant): QuantStub(
            (activation_post_process): HistogramObserver()
          )
          (dequant): DeQuantStub()
        )
        (fc_q): Sequential(
          (0): Linear(
            in_features=64, out_features=4, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU_Quantized(
            (quantized_op): FloatFunctional(
              (activation_post_process): HistogramObserver()
            )
            (quant): QuantStub(
              (activation_post_process): HistogramObserver()
            )
            (dequant): DeQuantStub()
          )
          (2): Linear(
            in_features=4, out_features=64, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
      )
      (add_residual_relu): FloatFunctional(
        (activation_post_process): HistogramObserver()
      )
    )
  )
  (layer2): Sequential(
    (0): IRBlock(
      (bn0): BatchNorm2d(
        64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): HistogramObserver()
      )
      (conv1): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn1): Identity()
      (prelu): PReLU(num_parameters=1)
      (prelu_q): PReLU_Quantized(
        (quantized_op): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (quant): QuantStub(
          (activation_post_process): HistogramObserver()
        )
        (dequant): DeQuantStub()
      )
      (conv2): Conv2d(
        64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn2): Identity()
      (downsample): Sequential(
        (0): Conv2d(
          64, 128, kernel_size=(1, 1), stride=(2, 2)
          (activation_post_process): HistogramObserver()
        )
        (1): Identity()
      )
      (se): SEBlock(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (mult_xy): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (fc): Sequential(
          (0): Linear(
            in_features=128, out_features=8, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU(num_parameters=1)
          (2): Linear(
            in_features=8, out_features=128, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
        (fc1): Linear(
          in_features=128, out_features=8, bias=True
          (activation_post_process): HistogramObserver()
        )
        (prelu): PReLU(num_parameters=1)
        (fc2): Linear(
          in_features=8, out_features=128, bias=True
          (activation_post_process): HistogramObserver()
        )
        (sigmoid): Sigmoid()
        (prelu_q): PReLU_Quantized(
          (quantized_op): FloatFunctional(
            (activation_post_process): HistogramObserver()
          )
          (quant): QuantStub(
            (activation_post_process): HistogramObserver()
          )
          (dequant): DeQuantStub()
        )
        (fc_q): Sequential(
          (0): Linear(
            in_features=128, out_features=8, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU_Quantized(
            (quantized_op): FloatFunctional(
              (activation_post_process): HistogramObserver()
            )
            (quant): QuantStub(
              (activation_post_process): HistogramObserver()
            )
            (dequant): DeQuantStub()
          )
          (2): Linear(
            in_features=8, out_features=128, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
      )
      (add_residual_relu): FloatFunctional(
        (activation_post_process): HistogramObserver()
      )
    )
    (1): IRBlock(
      (bn0): BatchNorm2d(
        128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): HistogramObserver()
      )
      (conv1): Conv2d(
        128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn1): Identity()
      (prelu): PReLU(num_parameters=1)
      (prelu_q): PReLU_Quantized(
        (quantized_op): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (quant): QuantStub(
          (activation_post_process): HistogramObserver()
        )
        (dequant): DeQuantStub()
      )
      (conv2): Conv2d(
        128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn2): Identity()
      (se): SEBlock(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (mult_xy): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (fc): Sequential(
          (0): Linear(
            in_features=128, out_features=8, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU(num_parameters=1)
          (2): Linear(
            in_features=8, out_features=128, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
        (fc1): Linear(
          in_features=128, out_features=8, bias=True
          (activation_post_process): HistogramObserver()
        )
        (prelu): PReLU(num_parameters=1)
        (fc2): Linear(
          in_features=8, out_features=128, bias=True
          (activation_post_process): HistogramObserver()
        )
        (sigmoid): Sigmoid()
        (prelu_q): PReLU_Quantized(
          (quantized_op): FloatFunctional(
            (activation_post_process): HistogramObserver()
          )
          (quant): QuantStub(
            (activation_post_process): HistogramObserver()
          )
          (dequant): DeQuantStub()
        )
        (fc_q): Sequential(
          (0): Linear(
            in_features=128, out_features=8, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU_Quantized(
            (quantized_op): FloatFunctional(
              (activation_post_process): HistogramObserver()
            )
            (quant): QuantStub(
              (activation_post_process): HistogramObserver()
            )
            (dequant): DeQuantStub()
          )
          (2): Linear(
            in_features=8, out_features=128, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
      )
      (add_residual_relu): FloatFunctional(
        (activation_post_process): HistogramObserver()
      )
    )
  )
  (layer3): Sequential(
    (0): IRBlock(
      (bn0): BatchNorm2d(
        128, eps=1e-05, momentum=0.

and for the sake of completeness, here are the whole modules used :

Summary
class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = max(0, inputs) + alpha * min(0, inputs) 
        # this is how we do it 
        # pos = torch.relu(inputs)
        # neg = -alpha * torch.relu(-inputs)
        # res3 = pos + neg
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(-self.weight, torch.relu(-inputs))
        inputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)
        inputs = self.dequant(inputs)
        self.weight = self.dequant(self.weight)
        return inputs

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # out += residual
        # out = self.relu(out)
        out = self.add_relu.add_relu(out, residual)

        return out

    def fuse_model(self):
        torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
                                               ['conv2', 'bn2']], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride
        
        self.skip_add_relu = nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        # out += residual
        # out = self.relu(out)
        out = self.skip_add_relu.add_relu(out, residual)
        return out

    def fuse_model(self):
        fuse_modules(self, [['conv1', 'bn1', 'relu1'],
                            ['conv2', 'bn2', 'relu2'],
                            ['conv3', 'bn3']], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.mult_xy = nn.quantized.FloatFunctional()

        self.fc = nn.Sequential(
                                nn.Linear(channel, channel // reduction),
                                nn.PReLU(),
                                # nn.ReLU(),
                                nn.Linear(channel // reduction, channel),
                                nn.Sigmoid()
                                )
        self.fc1 = self.fc[0]
        self.prelu = self.fc[1]
        self.fc2 = self.fc[2]
        self.sigmoid = self.fc[3]
        self.prelu_q = PReLU_Quantized(self.prelu)

    def forward(self, x):
        print(f'<inside se forward:>')
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        # y = self.fc(y).view(b, c, 1, 1)
        y = self.fc1(y)
        print(f'X: {y}')
        y = self.prelu_q(y)
        y = self.fc2(y)
        y = self.sigmoid(y).view(b, c, 1, 1)
        print('--------------------------')
        # out = x*y 
        out = self.mult_xy.mul(x, y)
        return out

class IRBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
        # self.prelu = nn.ReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

        self.add_residual_relu = nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = x

        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)

        # out = self.prelu(out)
        out = self.prelu_q(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # out += residual
        # out = self.prelu(out)

        # we may need to change prelu into relu and this, instead of add, use add_relu here
        out = self.add_residual_relu.add_relu(out, residual)
        # out = self.prelu(out)
        return out

    def fuse_model(self):
        fuse_modules(self, [['conv1', 'bn1'],# 'prelu'],
                            ['conv2', 'bn2']], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class ResNet(nn.Module):

    def __init__(self, block, layers, use_se=True):
        self.inplanes = 64
        self.use_se = use_se
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
        # self.prelu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn2 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(512 * 7 * 7, 512)
        self.bn3 = nn.BatchNorm1d(512)

        # self.bn2_q = BatchNorm2d_Quantized(self.bn2)
        # self.bn3_q = BatchNorm1d_Quantized(self.bn3)

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        
        x = self.quant(x)

        x = self.conv1(x)
        x = self.bn1(x)

        # x = self.prelu(x)
        x = self.prelu_q(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2(x)
        # x = self.bn2_q(x)
        x = self.dropout(x)
        # x = x.view(x.size(0), -1)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        x = self.bn3(x)
        # x = self.bn3_q(x)

        x = self.dequant(x)
        return x

    def fuse_model(self):
        r"""Fuse conv/bn/relu modules in resnet models
        Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """

        fuse_modules(self, [['conv1', 'bn1'],# 'prelu'],
                            # ['bn2'],  ['bn3']
                            ], inplace=True)
        for m in self.modules():
            # print(m)
            if type(m) == Bottleneck or type(m) == BasicBlock or type(m) == IRBlock:
                m.fuse_model()

def resnet18(pretrained, use_se, **kwargs):
    model = ResNet(IRBlock, [2, 2, 2, 2], use_se=use_se, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

Here is another sample, this is the model output when I removed all PReLUs and used ReLUs isntead(incase it was a hinderance):

Size (MB): 87.205847
QConfig(activation=functools.partial(<class 'torch.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
Model after quantization(converted-prepared): ResNet(
  (conv1): ConvReLU2d(
    (0): Conv2d(
      3, 64, kernel_size=(3, 3), stride=(1, 1)
      (activation_post_process): HistogramObserver()
    )
    (1): ReLU(
      (activation_post_process): HistogramObserver()
    )
  )
  (bn1): Identity()
  (prelu): PReLU(num_parameters=1)
  (prelu_q): PReLU_Quantized(
    (quantized_op): FloatFunctional(
      (activation_post_process): HistogramObserver()
    )
    (quant): QuantStub(
      (activation_post_process): HistogramObserver()
    )
    (dequant): DeQuantStub()
  )
  (reluooo): Identity()
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): IRBlock(
      (bn0): BatchNorm2d(
        64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): HistogramObserver()
      )
      (conv1): ConvReLU2d(
        (0): Conv2d(
          64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
          (activation_post_process): HistogramObserver()
        )
        (1): ReLU(
          (activation_post_process): HistogramObserver()
        )
      )
      (bn1): Identity()
      (prelu): PReLU(num_parameters=1)
      (prelu_q): PReLU_Quantized(
        (quantized_op): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (quant): QuantStub(
          (activation_post_process): HistogramObserver()
        )
        (dequant): DeQuantStub()
      )
      (reluooo): Identity()
      (conv2): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn2): Identity()
      (se): SEBlock(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (mult_xy): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (fc): Sequential(
          (0): Linear(
            in_features=64, out_features=4, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU(num_parameters=1)
          (2): Linear(
            in_features=4, out_features=64, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
        (fc1): Linear(
          in_features=64, out_features=4, bias=True
          (activation_post_process): HistogramObserver()
        )
        (prelu): PReLU(num_parameters=1)
        (fc2): Linear(
          in_features=4, out_features=64, bias=True
          (activation_post_process): HistogramObserver()
        )
        (sigmoid): Sigmoid()
      )
      (add_residual_relu): FloatFunctional(
        (activation_post_process): HistogramObserver()
      )
    )
    (1): IRBlock(
      (bn0): BatchNorm2d(
        64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): HistogramObserver()
      )
      (conv1): ConvReLU2d(
        (0): Conv2d(
          64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
          (activation_post_process): HistogramObserver()
        )
        (1): ReLU(
          (activation_post_process): HistogramObserver()
        )
      )
      (bn1): Identity()
      (prelu): PReLU(num_parameters=1)
      (prelu_q): PReLU_Quantized(
        (quantized_op): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (quant): QuantStub(
          (activation_post_process): HistogramObserver()
        )
        (dequant): DeQuantStub()
      )
      (reluooo): Identity()
      (conv2): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn2): Identity()
      (se): SEBlock(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (mult_xy): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (fc): Sequential(
          (0): Linear(
            in_features=64, out_features=4, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU(num_parameters=1)
          (2): Linear(
            in_features=4, out_features=64, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
        (fc1): Linear(
          in_features=64, out_features=4, bias=True
          (activation_post_process): HistogramObserver()
        )
        (prelu): PReLU(num_parameters=1)
        (fc2): Linear(
          in_features=4, out_features=64, bias=True
          (activation_post_process): HistogramObserver()
        )
        (sigmoid): Sigmoid()
      )
      (add_residual_relu): FloatFunctional(
        (activation_post_process): HistogramObserver()
      )
    )
  )
  (layer2): Sequential(
    (0): IRBlock(
      (bn0): BatchNorm2d(
        64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): HistogramObserver()
      )
      (conv1): ConvReLU2d(
        (0): Conv2d(
          64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
          (activation_post_process): HistogramObserver()
        )
        (1): ReLU(
          (activation_post_process): HistogramObserver()
        )
      )
      (bn1): Identity()
      (prelu): PReLU(num_parameters=1)
      (prelu_q): PReLU_Quantized(
        (quantized_op): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (quant): QuantStub(
          (activation_post_process): HistogramObserver()
        )
        (dequant): DeQuantStub()
      )
      (reluooo): Identity()
      (conv2): Conv2d(
        64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn2): Identity()
      (downsample): Sequential(
        (0): Conv2d(
          64, 128, kernel_size=(1, 1), stride=(2, 2)
          (activation_post_process): HistogramObserver()
        )
        (1): Identity()
      )
      (se): SEBlock(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (mult_xy): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (fc): Sequential(
          (0): Linear(
            in_features=128, out_features=8, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU(num_parameters=1)
          (2): Linear(
            in_features=8, out_features=128, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
        (fc1): Linear(
          in_features=128, out_features=8, bias=True
          (activation_post_process): HistogramObserver()
        )
        (prelu): PReLU(num_parameters=1)
        (fc2): Linear(
          in_features=8, out_features=128, bias=True
          (activation_post_process): HistogramObserver()
        )
        (sigmoid): Sigmoid()
      )
      (add_residual_relu): FloatFunctional(
        (activation_post_process): HistogramObserver()
      )
    )
    (1): IRBlock(
      (bn0): BatchNorm2d(
        128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): HistogramObserver()
      )
      (conv1): ConvReLU2d(
        (0): Conv2d(
          128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
          (activation_post_process): HistogramObserver()
        )
        (1): ReLU(
          (activation_post_process): HistogramObserver()
        )
      )
      (bn1): Identity()
      (prelu): PReLU(num_parameters=1)
      (prelu_q): PReLU_Quantized(
        (quantized_op): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (quant): QuantStub(
          (activation_post_process): HistogramObserver()
        )
        (dequant): DeQuantStub()
      )
      (reluooo): Identity()
      (conv2): Conv2d(
        128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation_post_process): HistogramObserver()
      )
      (bn2): Identity()
      (se): SEBlock(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (mult_xy): FloatFunctional(
          (activation_post_process): HistogramObserver()
        )
        (fc): Sequential(
          (0): Linear(
            in_features=128, out_features=8, bias=True
            (activation_post_process): HistogramObserver()
          )
          (1): PReLU(num_parameters=1)
          (2): Linear(
            in_features=8, out_features=128, bias=True
            (activation_post_process): HistogramObserver()
          )
          (3): Sigmoid()
        )
        (fc1): Linear(
          in_features=128, out_features=8, bias=True
          (activation_post_process): HistogramObserver()
        )
        (prelu): PReLU(num_parameters=1)
        (fc2): Linear(
          in_features=8, out_features=128, bias=True
          (activation_post_process): HistogramObserver()
        )
        (sigmoid): Sigmoid()
      )
      (add_residual_relu): FloatFunctional(
        (activation_post_process): HistogramObserver()
      )
    )
  )
  (layer3): Sequential(
    (0)
Post Training Quantization Prepare: Inserting Observers

 Inverted Residual Block:After observer insertion

 ConvReLU2d(
  (0): Conv2d(
    3, 64, kernel_size=(3, 3), stride=(1, 1)
    (activation_post_process): HistogramObserver()
  )
  (1): ReLU(
    (activation_post_process): HistogramObserver()
  )
)

This is the error I get when using this model (above):

--------------------------
Traceback (most recent call last):
  File "d:\Codes\org\python\Quantization\quantizer.py", line 270, in <module>
    test_the_model(True)
  File "d:\Codes\org\python\Quantization\quantizer.py", line 218, in test_the_model
    check_and_tell(model, pic1, pic2)
  File "d:\Codes\org\python\Quantization\quantizer.py", line 203, in check_and_tell
    embd1 = model(img1.unsqueeze(0))
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 599, in forward
    x = self.bn3(x)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\batchnorm.py", line 136, in forward
    self.weight, self.bias, bn_training, exponential_average_factor, self.eps)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\functional.py", line 2039, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU' backend. 'aten::native_batch_norm' is only available for these backends: [CPU, MkldnnCPU, BackendSelect, Named, Autograd, Profiler, Tracer, Autocast, Batched].

CPU: registered at aten\src\ATen\CPUType.cpp:1594 [kernel]
MkldnnCPU: registered at aten\src\ATen\MkldnnCPUType.cpp:139 [kernel]
BackendSelect: fallthrough registered at ..\aten\src\ATen\core\BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at ..\aten\src\ATen\core\NamedRegistrations.cpp:7 [backend fallback]
Autograd: registered at ..\torch\csrc\autograd\generated\VariableType_0.cpp:7879 [kernel]
Profiler: registered at ..\torch\csrc\autograd\generated\ProfiledType_0.cpp:2050 [kernel]
Tracer: registered at ..\torch\csrc\autograd\generated\TraceType_0.cpp:8256 [kernel]
Autocast: fallthrough registered at ..\aten\src\ATen\autocast_mode.cpp:375 [backend fallback]
Batched: registered at ..\aten\src\ATen\BatchingRegistrations.cpp:149 [backend fallback]

I tried commenting out self.bn(x), and the code ran through. Is there any other solution to this problem?

Thanks, but thats not a solution to me, removing bn drastically affects the performance and aside from that, @supriyar says Pytorch has a quantized version of BatchNorm in place and it should have got converted in first place!
So its not known what is missing or what else needs to be done.

You can replace Linear with a 1*1 convolutional layer, and then merge the convolutional layer and bn layer. I have tried and solved my problem

Thanks, but the problem is, I have other instances of BN where they are used alone! one instance is in IRBlock where the first layer is bn!
So I need to fix this properly

You can try to add a 1*1 convolutional layer before the bn layer

That way I have to retrain the model as the 1x1 weights are uninitialized

Looking at the first conv of your model after convert, it doesn’t seem like it is actually quantized (it should be QuantizedConv), same for the subsequent modules. One thing to debug would be why the modules are not getting replaced.

Are you calling model.eval() before running convert?

Hi, Thanks, but no that model is not yet converted, what is printed up there, is just the output after running fuse_model() and then torch.quantization.prepare. if I comment out the single bns, (and also repalce PReLUs to not face the current issues), the final model does get quantized (its size becomes 22Mb from 88Mb and you see the QuantizedConv2d, etc as well.

the error message looks like you are trying to pass a quantized input to BN, but BN is not quantized. So, you’d need to either fuse it to the preceding module, quantize it, or make sure the input is converted to floating point. Here is a toy example of expected behavior:

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.bn2 = nn.BatchNorm2d(1)
    
    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return x
    
m = M()
m.qconfig = torch.quantization.default_qconfig
m.eval()

torch.quantization.fuse_modules(
    m, 
    [
        ['conv1', 'bn1'], # fuse bn1 into conv1
        # for example's sake, don't fuse conv2 and bn2
    ],
    inplace=True)

torch.quantization.prepare(m, inplace=True)

# toy calibration
data = torch.randn(32, 1, 16, 16)
m(data)

torch.quantization.convert(m, inplace=True)
# self.bn1 was fused with conv1 earlier
# self.bn2 will be QuantizedBatchNorm2d
print(m)
1 Like

Thanks a lot. good point. on a normal model this looks alright, but in the self contained example I made, this doesnt apply,.
Here have a look :
Here is a self contained example with Resnet18 and SimpleNetwork ( a simple 2 layered CNN) using fake data to demonstrate the problem. You can change the use_relu and disable_single_bns to see different results:

import os
from os.path import abspath, dirname, join
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
import torchvision.transforms as transforms
from torch.quantization import fuse_modules

use_relu = False
disable_single_bns = False

class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = max(0, inputs) + alpha * min(0, inputs) 
        # this is how we do it 
        # pos = torch.relu(inputs)
        # neg = -alpha * torch.relu(-inputs)
        # res3 = pos + neg
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(-self.weight, torch.relu(-inputs))
        inputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)
        inputs = self.dequant(inputs)
        self.weight = self.dequant(self.weight)
        return inputs

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        # out += residual
        # out = self.relu(out)
        out = self.add_relu.add_relu(out, residual)
        return out

    def fuse_model(self):
        torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
                                               ['conv2', 'bn2']], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride
        self.skip_add_relu = nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        # out += residual
        # out = self.relu(out)
        out = self.skip_add_relu.add_relu(out, residual)
        return out

    def fuse_model(self):
        fuse_modules(self, [['conv1', 'bn1', 'relu1'],
                            ['conv2', 'bn2', 'relu2'],
                            ['conv3', 'bn3']], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.mult_xy = nn.quantized.FloatFunctional()
        self.fc = nn.Sequential(nn.Linear(channel, channel // reduction),
                                nn.PReLU(),
                                nn.Linear(channel // reduction, channel),
                                nn.Sigmoid())
        self.fc1 = self.fc[0]
        self.prelu = self.fc[1]
        self.fc2 = self.fc[2]
        self.sigmoid = self.fc[3]
        self.prelu_q = PReLU_Quantized(self.prelu)
        if use_relu:
            self.prelu_q_or_relu = torch.relu
        else:
            self.prelu_q_or_relu = self.prelu_q

    def forward(self, x):
        # print(f'<inside se forward:>')
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        # y = self.fc(y).view(b, c, 1, 1)
        y = self.fc1(y)
        y = self.prelu_q_or_relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y).view(b, c, 1, 1)
        # print('--------------------------')
        # out = x*y 
        out = self.mult_xy.mul(x, y)
        return out

class IRBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        if disable_single_bns:
            self.bn0_or_identity = torch.nn.Identity()
        else:
            self.bn0_or_identity = self.bn0

        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
        
        if use_relu:
            self.prelu_q_or_relu = torch.relu
        else:
            self.prelu_q_or_relu = self.prelu_q

        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        # if self.use_se:
        self.se = SEBlock(planes)
        self.add_residual = nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = x
        # TODO:
        # this needs to be quantized as well!
        out = self.bn0_or_identity(x)

        out = self.conv1(out)
        out = self.bn1(out)
        # out = self.prelu(out)
        out = self.prelu_q_or_relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        # out += residual
        # out = self.prelu(out)
        out = self.prelu_q_or_relu(out)
        # we may need to change prelu into relu and instead of add, use add_relu here
        out = self.add_residual.add(out, residual)
        return out

    def fuse_model(self):
        fuse_modules(self, [# ['bn0'],
                            ['conv1', 'bn1'],
                            ['conv2', 'bn2']], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class ResNet(nn.Module):

    def __init__(self, block, layers, use_se=True):
        self.inplanes = 64
        self.use_se = use_se
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
        # This is to only get rid of the unimplemented CPUQuantization type error
        # when we use PReLU_Quantized during test time
        if use_relu:
            self.prelu_q_or_relu = torch.relu
        else:
             self.prelu_q_or_relu = self.prelu_q

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn2 = nn.BatchNorm2d(512)
        # This is to get around the single BatchNorms not getting fused and thus causing 
        # a RuntimeError: Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU' backend.
        # 'aten::native_batch_norm' is only available for these backends: [CPU, MkldnnCPU, BackendSelect, Named, Autograd, Profiler, Tracer, Autocast, Batched].
        # during test time
        if disable_single_bns:
            self.bn2_or_identity = torch.nn.Identity()
        else:
            self.bn2_or_identity = self.bn2

        self.dropout = nn.Dropout()
        self.fc = nn.Linear(512 * 7 * 7, 512)
        self.bn3 = nn.BatchNorm1d(512)
        if disable_single_bns:
            self.bn3_or_identity = torch.nn.Identity()
        else:
            self.bn3_or_identity = self.bn3
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        
        x = self.quant(x)
        x = self.conv1(x)
        # TODO: single bn needs to be fused
        x = self.bn1(x)

        # x = self.prelu(x)
        x = self.prelu_q_or_relu(x)

        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2_or_identity(x)
        x = self.dropout(x)
        # x = x.view(x.size(0), -1)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        # TODO: single bn needs to be fused
        x = self.bn3_or_identity(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        r"""Fuse conv/bn/relu modules in resnet models
        Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """
        fuse_modules(self, ['conv1', 'bn1'], inplace=True)
        for m in self.modules():
            if type(m) == Bottleneck or type(m) == BasicBlock or type(m) == IRBlock:
                m.fuse_model()

class SimpleNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.relu1 = nn.ReLU()

        self.prelu_q = PReLU_Quantized(nn.PReLU())
        self.bn = nn.BatchNorm2d(10)

        self.prelu_q_or_relu = torch.relu if use_relu else self.prelu_q
        self.bn_or_identity = nn.Identity() if disable_single_bns else self.bn    

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
    
    def forward(self, x):
        x = self.quant(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.prelu_q_or_relu(x)
        x = self.bn_or_identity(x)

        x = self.dequant(x)
        return x

def resnet18(use_se=True, **kwargs):
    return ResNet(IRBlock, [2, 2, 2, 2], use_se=use_se, **kwargs)

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

def evaluate(model, data_loader, eval_batches):
    model.eval()
    with torch.no_grad():
        for i, (image, target) in enumerate(data_loader):
            features = model(image)
            print(f'{i})feature dims: {features.shape}')
            if i >= eval_batches:
                return

def load_quantized(model, quantized_checkpoint_file_path):
    model.eval()
    if type(model) == ResNet:
        model.fuse_model()
    # Specify quantization configuration
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)
    # Convert to quantized model
    torch.quantization.convert(model, inplace=True)
    checkpoint = torch.load(quantized_checkpoint_file_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    print_size_of_model(model)
    return model

def test_the_model(model, dtloader):
    current_dir = abspath(dirname(__file__))
    model = load_quantized(model, join(current_dir, 'data', 'model_quantized_jit.pth'))
    model.eval()
    img, _ = next(iter(dtloader))
    embd1 = model(img)

def quantize_model(model, dtloader):
    calibration_batches = 10 
    saved_model_dir = 'data'
    scripted_quantized_model_file = 'model_quantized_jit.pth'
    # model = resnet18()
    model.eval()
    if type(model) == ResNet:
        model.fuse_model()
    print_size_of_model(model)
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    print(model.qconfig)
    torch.quantization.prepare(model, inplace=True)

    print(f'Model after fusion(prepared): {model}')

    # Calibrate first
    print('Post Training Quantization Prepare: Inserting Observers')
    print('\n Inverted Residual Block:After observer insertion \n\n', model.conv1)

    # Calibrate with the training set
    evaluate(model, dtloader, eval_batches=calibration_batches)
    print('Post Training Quantization: Calibration done')

    # Convert to quantized model
    torch.quantization.convert(model, inplace=True)
    print('Post Training Quantization: Convert done')
    print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n', model.conv1)

    print("Size of model after quantization")
    print_size_of_model(model)
    script = torch.jit.script(model)
    path_tosave = join(dirname(abspath(__file__)), saved_model_dir, scripted_quantized_model_file)
    print(f'path to save: {path_tosave}')
    with open(path_tosave, 'wb') as f:
        torch.save(model.state_dict(), f)

    print(f'model after quantization (prepared and converted:) {model}')
    # torch.jit.save(script, path_tosave)

dataset = FakeData(1000, image_size=(3, 112, 112), num_classes=5, transform=transforms.ToTensor())
data_loader = DataLoader(dataset, batch_size=1)

# quantize the model 
model = resnet18()
# model = SimpleNetwork()
quantize_model(model, data_loader)

# and load and test the quantized model
model = resnet18()
# model = SimpleNetwork()
test_the_model(model, data_loader)

I changed the SimpleNetwork, based on what you suggested and It doesnt fail anymore, but this is not the case with the ResNet18.

I’ll try to dig a bit more and see why I find .
Thanks alot for your time really appreciate it

I noticed two things so far:
Pytorch has issues with branches in the model for some reason. that is, lets consider SimpleNetwork here.


class SimpleNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.relu1 = nn.ReLU()

        self.prelu_q = PReLU_Quantized(nn.PReLU())
        self.bn = nn.BatchNorm2d(10)

        self.prelu_q_or_relu = torch.relu if use_relu else self.prelu_q
        self.bn_or_identity = nn.Identity() if disable_single_bns else self.bn    

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
    
    def forward(self, x):
        x = self.quant(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.prelu_q_or_relu(x)
        x = self.bn_or_identity(x)

        x = self.dequant(x)
        return x

This by default results in the infamous error stated in the op. However, if I simply remove :

self.bn_or_identity = nn.Identity() if disable_single_bns else self.bn    

and simply use the

self.bn = nn.BatchNorm2d(10)

in the forward pass, I no longer see that error!.
I tried to do the same thing to ResNet18, and it seems, all previous bns are fine except the bn3 in the ResNet model (the penultimate layer) which regardless of what I change, still give that error !

what is the error when you use
self.bn_or_identity = nn.Identity() if disable_single_bns else self.bn ? I think this is what you need to do