Satic Quantization of Inception Resnet Model

Hi, I am trying to perform static quantization of the Inception ResNet model. I made some minor modifications. here is the code for the model

import os
import requests
from requests.adapters import HTTPAdapter
import torch
from torch import nn
from torch.nn import functional as F
from torch.quantization import QuantStub, DeQuantStub

from facenet_pytorch.models.utils.download import download_url_to_file


class BasicConv2d(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes, out_planes,
            kernel_size=kernel_size, stride=stride,
            padding=padding, bias=False
        ) # verify bias false
        self.bn = nn.BatchNorm2d(
            out_planes,
            eps=0.001, # value found in tensorflow
            momentum=0.1, # default pytorch value
            affine=True
        )
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Block35(nn.Module):

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(256, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(256, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(256, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        )

        self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)
        self.ff1 = nn.quantized.FloatFunctional()
        self.ff2 = nn.quantized.FloatFunctional()

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        out = self.conv2d(out)
        # out = out * self.scale + x
        out = self.ff2.add(self.ff1.mul(out, self.scale), x)
        out = self.relu(out)
        return out


class Block17(nn.Module):

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(896, 128, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(896, 128, kernel_size=1, stride=1),
            BasicConv2d(128, 128, kernel_size=(1,7), stride=1, padding=(0,3)),
            BasicConv2d(128, 128, kernel_size=(7,1), stride=1, padding=(3,0))
        )

        self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Block8(nn.Module):

    def __init__(self, scale=1.0, noReLU=False):
        super().__init__()

        self.scale = scale
        self.noReLU = noReLU

        self.branch0 = BasicConv2d(1792, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(1792, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=(1,3), stride=1, padding=(0,1)),
            BasicConv2d(192, 192, kernel_size=(3,1), stride=1, padding=(1,0))
        )

        self.conv2d = nn.Conv2d(384, 1792, kernel_size=1, stride=1)
        if not self.noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        if not self.noReLU:
            out = self.relu(out)
        return out


class Mixed_6a(nn.Module):

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv2d(256, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv2d(256, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1),
            BasicConv2d(192, 256, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Mixed_7a(nn.Module):

    def __init__(self):
        super().__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(896, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2)
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(896, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=2)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(896, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 256, kernel_size=3, stride=2)
        )

        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class InceptionResnetV1(nn.Module):
    """Inception Resnet V1 model with optional loading of pretrained weights.
    Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface
    datasets. Pretrained state_dicts are automatically downloaded on model instantiation if
    requested and cached in the torch cache. Subsequent instantiations use the cache rather than
    redownloading.
    Keyword Arguments:
        pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'.
            (default: {None})
        classify {bool} -- Whether the model should output classification probabilities or feature
            embeddings. (default: {False})
        num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not
            equal to that used for the pretrained model, the final linear layer will be randomly
            initialized. (default: {None})
        dropout_prob {float} -- Dropout probability. (default: {0.6})
    """
    def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None):
        super().__init__()

        # Set simple attributes
        self.pretrained = pretrained
        self.classify = classify
        self.num_classes = num_classes

        if pretrained == 'vggface2':
            tmp_classes = 8631
        elif pretrained == 'casia-webface':
            tmp_classes = 10575
        elif pretrained is None and self.classify and self.num_classes is None:
            raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified')


        # Define layers
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2)
        self.repeat_1 = nn.Sequential(
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
        )
        self.mixed_6a = Mixed_6a()
        self.repeat_2 = nn.Sequential(
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
        )
        self.mixed_7a = Mixed_7a()
        self.repeat_3 = nn.Sequential(
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
        )
        self.block8 = Block8(noReLU=True)
        self.avgpool_1a = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout_prob)
        self.last_linear = nn.Linear(1792, 512, bias=False)
        self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

        if pretrained is not None:
            self.logits = nn.Linear(512, tmp_classes)
            load_weights(self, pretrained)

        if self.classify and self.num_classes is not None:
            self.logits = nn.Linear(512, self.num_classes)

        self.device = torch.device('cpu')
        if device is not None:
            self.device = device
            self.to(device)

    def forward(self, x):
        """Calculate embeddings or logits given a batch of input image tensors.
        Arguments:
            x {torch.tensor} -- Batch of image tensors representing faces.
        Returns:
            torch.tensor -- Batch of embedding vectors or multinomial logits.
        """
        x = self.quant(x)
        x = self.conv2d_1a(x)
        x = self.conv2d_2a(x)
        x = self.conv2d_2b(x)
        x = self.maxpool_3a(x)
        x = self.conv2d_3b(x)
        x = self.conv2d_4a(x)
        x = self.conv2d_4b(x)
        x = self.repeat_1(x)
        x = self.mixed_6a(x)
        x = self.repeat_2(x)
        x = self.mixed_7a(x)
        x = self.repeat_3(x)
        x = self.block8(x)
        x = self.avgpool_1a(x)
        x = self.dropout(x)
        x = self.last_linear(x.view(x.shape[0], -1))
        x = self.last_bn(x)        
        if self.classify:
            x = self.logits(x)
        else:
            x = F.normalize(x, p=2, dim=1)
        x = self.dequant(x)
        return x
    
    def fuse_model(self):
        for m in self.modules():
            if type(m) == BasicConv2d:
                torch.quantization.fuse_modules(m, ['conv', 'bn', 'relu'], inplace=True)


def load_weights(mdl, name):
    """Download pretrained state_dict and load into model.
    Arguments:
        mdl {torch.nn.Module} -- Pytorch model.
        name {str} -- Name of dataset that was used to generate pretrained state_dict.
    Raises:
        ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'.
    """
    if name == 'vggface2':
        path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt'
    elif name == 'casia-webface':
        path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt'
    else:
        raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"')

    model_dir = os.path.join(get_torch_home(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    cached_file = os.path.join(model_dir, os.path.basename(path))
    if not os.path.exists(cached_file):
        download_url_to_file(path, cached_file)

    state_dict = torch.load(cached_file)
    mdl.load_state_dict(state_dict)


def get_torch_home():
    torch_home = os.path.expanduser(
        os.getenv(
            'TORCH_HOME',
            os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')
        )
    )
    return torch_home

Model loading and configuration setting section

criterion = nn.CrossEntropyLoss()
model_inception_resnet = InceptionResnetV1(pretrained='vggface2', classify=True).eval()
# Fuse Conv, bn and relu
model_inception_resnet.fuse_model()

# Specify quantization configuration
# Start with simple min/max range estimation and per-tensor quantization of weights
model_inception_resnet.qconfig = torch.quantization.default_qconfig
print(model_inception_resnet.qconfig)
torch.quantization.prepare(model_inception_resnet, inplace=True)

# Convert to quantized model
torch.backends.quantized.engine = 'qnnpack'
torch.quantization.convert(model_inception_resnet, inplace=True)

After this when I am trying to evaluate the accuracy of the model over VGGFace2 dataset I am getting an error stating

RuntimeError: Could not run ‘aten::dequantize.self’ with arguments from the ‘CPU’ backend. ‘aten::dequantize.self’ is only available for these backends: [QuantizedCPU, QuantizedCUDA, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

The full error output is provided below. Any help to resolve this would be appreciated. Thanks!

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-8ae51e49d6c3> in <module>
      1 ### Accuracy of quantized model on validation Set
----> 2 top1, top5 = evaluate(model_inception_resnet, criterion, train_loader, neval_batches=num_eval_batches)
      3 print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))

<ipython-input-4-46bd23fe98da> in evaluate(model, criterion, data_loader, neval_batches)
     48         for image, target in data_loader:
     49             # output = model(image)
---> 50             output = model_inception_resnet(transforms.ToTensor()(image).unsqueeze(0))
     51             loss = criterion(output, torch.tensor([target]))
     52             cnt += 1

/home/canservers/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-18-8074618c5a08> in forward(self, x)
    303         else:
    304             x = F.normalize(x, p=2, dim=1)
--> 305         x = self.dequant(x)
    306         return x
    307 

/home/canservers/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/home/canservers/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py in forward(self, Xq)
     78 
     79     def forward(self, Xq):
---> 80         return Xq.dequantize()
     81 
     82     @staticmethod

it means you are dequantizing a Tensor that is not quantized, meaning x = self.dequant(x) is not placed at a correct place. I think you probably want to move it to after self.logits and before normalize, e.g.:

 if self.classify:
            x = self.logits(x)
            x = self.dequant(x)
        else:
            x = self.dequant(x)
            x = F.normalize(x, p=2, dim=1)
        return x

assuming self.logits is quantized and self.last_bn is quantized.

But in general you will need to manually place this at a correct place for eager mode quantization.

Hi Jerry, Thanks for your prompt response.

I am having another issue in the model on the following line in class Block35

I was getting an error stating it is getting a float value instead of tensor. So I changed it into

out = self.ff1.mul(torch.tensor(out), torch.tensor(self.scale))
out = self.ff2.add(out, x)

But now I am getting an error

RuntimeError: Could not run 'aten::empty_strided' with arguments from the 'QuantizedCPU' backend. 'aten::empty_strided' is only available for these backends: ........

How can I resolve this? Do I need to add a QuantStub() and DeQuantStub() block in each of the forward function pass of the individual class?

if x is a scalar value, can you try using out = self.ff2.add_scalar(self.ff1.mul(out, self.scale), x) instead?

I tried what you suggested. This is what I got. Any suggestion would be appreciated. Thanks

RuntimeError: Overloaded torch operator invoked from Python failed to many any schema:
quantized::mul() Expected a value of type 'Tensor' for argument 'qb' but instead found type 'float'.
Position: 1
Value: 0.17
Declaration: quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point) -> (Tensor qc)
Cast error details: Unable to cast Python instance to C++ type (compile in debug mode for details)

quantized::mul() expected at most 3 argument(s) but received 4 argument(s). Declaration: quantized::mul.out(Tensor qa, Tensor qb, Tensor(a!) out) -> (Tensor(a!) out)

quantized::mul() expected at most 2 argument(s) but received 4 argument(s). Declaration: quantized::mul.Scalar(Tensor qa, Scalar b) -> (Tensor qc)

quantized::mul() expected at most 3 argument(s) but received 4 argument(s). Declaration: quantized::mul.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> (Tensor(a!) out)

A few more details abut the error message location

/home/canservers/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-3-12297d4c2fda> in forward(self, x)
     66         out = torch.cat((x0, x1, x2), 1)
     67         out = self.conv2d(out)
---> 68         out = self.ff2.add_scalar(self.ff1.mul(out, self.scale), x)
     69         out = self.relu(out)
     70         return out

/home/canservers/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/quantized/modules/functional_modules.py in mul(self, x, y)
    158     def mul(self, x, y):
    159         # type: (Tensor, Tensor) -> Tensor
--> 160         r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
    161         r = self.activation_post_process(r)
    162         return r

it looks like self.ff1.mul should also be self.ff1.mul_scalar

So finally I managed to get back to the Inception model and try out your suggestion. After a couple of tries this is what worked finally

out = self.ff2.add(self.ff1.mul_scalar(out, self.scale), x)

I had to make a few other changes in other blocks to incorporate this logic. The fuse_model that I wrote can probably be improved as well. The notebook is here if someone wants to have a look.

There is still some issue with torch.jit.script saving but that is another issue

Thanks @jerryzh168 for your help! :blush: This can be closed out.

1 Like