Quantized QNNPACK model is slowing down on android

Hi! After quantization, my segmentation model runs slower than the float model, both on Android and on a desktop CPU, and I can’t figure out why. I tried different ways to quantize and different layers. Below is a full example using QNNPACK.

On Android I use the following dependencies (I also tried other versions):

  implementation 'org.pytorch:pytorch_android:1.5.0-SNAPSHOT'
   implementation 'org.pytorch:pytorch_android_torchvision:1.5.0-SNAPSHOT'
import torch
import os

from torch import nn
from torchvision.models.resnet import BasicBlock, ResNet
from torch.quantization import fuse_modules
from torch.nn import functional as F


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 2D convolution with a 1x1 kernel.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=1`` and ``bias=False``.
    """
    conv_kwargs = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class QuantizableBasicBlock(BasicBlock):
    """torchvision ``BasicBlock`` adapted for eager-mode static quantization.

    The final residual sum followed by ReLU is routed through a
    ``FloatFunctional`` (``add_relu``) instead of ``out += identity;
    relu(out)``, so it can be observed and swapped for a quantized op.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x):
        # Main branch: conv -> bn -> relu -> conv -> bn.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Shortcut branch: projected when a downsample module exists.
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)

        return self.add_relu.add_relu(residual, shortcut)

    def fuse_model(self):
        """Fuse conv1+bn1+relu, conv2+bn2 and the downsample conv+bn in place."""
        torch.quantization.fuse_modules(
            self, [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']], inplace=True)
        downsample = self.downsample
        if downsample:
            # downsample is built as Sequential(conv1x1, norm_layer).
            torch.quantization.fuse_modules(downsample, ['0', '1'], inplace=True)


class ResNet(nn.Module):
    """Reduced-width ResNet backbone with stage widths 16/32/64/128.

    Mirrors torchvision's ResNet constructor, but always builds its stages
    from ``QuantizableBasicBlock`` and uses much narrower stages.

    Args:
        block: ignored; ``QuantizableBasicBlock`` is always used.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the ``fc`` classifier head.
        zero_init_residual: accepted for signature compatibility; unused here.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: three booleans (stages 2-4); a True
            entry replaces that stage's stride with dilation.
        norm_layer: normalization layer factory (default ``nn.BatchNorm2d``).

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        shapes = (16, 32, 64, 128)
        # NOTE(review): the ``block`` argument is overridden, so callers cannot
        # pick another block type. QuantizableResNet.fuse_model() only fuses
        # QuantizableBasicBlock instances.
        block = QuantizableBasicBlock
        self.inplanes = shapes[0]
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False, dilation=1)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, shapes[0], layers[0])
        self.layer2 = self._make_layer(block, shapes[1], layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, shapes[2], layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, shapes[3], layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # layer4 outputs shapes[-1] * expansion channels (128 here). The former
        # hard-coded 512 only matches torchvision's full-width ResNet and made
        # forward() fail with a shape mismatch in fc.
        self.fc = nn.Linear(shapes[-1] * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks with ``planes`` width.

        The first block may stride/project (via a conv1x1 + norm downsample);
        the rest keep the resolution. Updates ``self.inplanes`` and, when
        ``dilate`` is set, ``self.dilation``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run stem, the four stages, global average pooling and ``fc``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)


def _replace_relu(module):
    reassign = {}
    for name, mod in module.named_children():
        _replace_relu(mod)
        if type(mod) == nn.ReLU or type(mod) == nn.ReLU6:
            reassign[name] = nn.ReLU(inplace=False)

    for key, value in reassign.items():
        module._modules[key] = value


class QuantizableResNet(ResNet):
    """``ResNet`` wrapped in QuantStub/DeQuantStub for eager-mode quantization."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self._forward_impl(self.quant(x)))

    def fuse_model(self):
        """Fuse the stem and every residual/decoder block in place."""
        fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True)
        fusable_types = (QuantizableBasicBlock, DecoderBlock)
        for submodule in self.modules():
            if isinstance(submodule, fusable_types):
                submodule.fuse_model()


class Conv2dReLU(nn.Module):
    """Conv2d -> optional BatchNorm2d -> ReLU, laid out for eager-mode fusion.

    Args:
        in_channels, out_channels, kernel_size, padding, stride: passed to
            ``nn.Conv2d``.
        use_batchnorm: insert a ``BatchNorm2d`` between conv and ReLU; the
            conv is then created without bias.
        **batchnorm_params: extra keyword arguments for ``nn.BatchNorm2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 stride=1, use_batchnorm=True, **batchnorm_params):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, bias=not use_batchnorm, groups=1),
            nn.ReLU(),
        ]

        if use_batchnorm:
            layers.insert(1, nn.BatchNorm2d(out_channels, **batchnorm_params))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

    def fuse_model(self):
        """Fuse ``block`` in place into a single conv(+bn)+relu module.

        Conv+BN fusion folds the BatchNorm statistics into the conv, so this
        is meant to be called on a model in eval mode.
        """
        # block is [conv, bn, relu] with batchnorm and [conv, relu] without;
        # fusing all of its entries handles both layouts (the previous
        # hard-coded 'block.2' crashed when use_batchnorm=False).
        names = ['block.{}'.format(i) for i in range(len(self.block))]
        fuse_modules(self, names, inplace=True)


class Up(nn.Module):
    """Bilinear upsampling with aligned corners.

    Args:
        scale_factor: spatial upsampling factor (default 2).
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        # F.upsample_bilinear is deprecated; its documented replacement is
        # F.interpolate with mode='bilinear' and align_corners=True.
        return F.interpolate(x, scale_factor=self.scale_factor,
                             mode='bilinear', align_corners=True)


class DecoderBlock(nn.Module):
    """1x1 Conv2dReLU (channels // 4) -> Up -> 1x1 Conv2dReLU (out_channels).

    Args:
        in_channels: input channel count; the bottleneck uses ``in_channels // 4``.
        out_channels: output channel count.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        bottleneck = in_channels // 4
        reduce = Conv2dReLU(in_channels, bottleneck, kernel_size=1, use_batchnorm=True)
        expand = Conv2dReLU(bottleneck, out_channels, kernel_size=1, use_batchnorm=True)
        self.block = nn.Sequential(reduce, Up(), expand)

    def forward(self, x):
        return self.block(x)

    def fuse_model(self):
        """Fuse every ``Conv2dReLU`` inside this block in place."""
        for sub in self.modules():
            if not isinstance(sub, Conv2dReLU):
                continue
            sub.fuse_model()


class ResNetUnet(QuantizableResNet):
    """U-Net style segmentation model on the reduced quantizable ResNet encoder.

    Five ``DecoderBlock``s each upsample by 2. The outputs of ``block1`` to
    ``block4`` are summed with the matching encoder feature maps (skip
    connections). A small linear head on the pooled deepest features is
    returned alongside the segmentation output.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pretrained = False
        in_channels = (128, 64, 32, 16, 16)
        prefinal_channels = 16
        final_channels = 3
        self.block1 = DecoderBlock(in_channels[0], in_channels[1])
        self.block2 = DecoderBlock(in_channels[1], in_channels[2])
        self.block3 = DecoderBlock(in_channels[2], in_channels[3])
        self.block4 = DecoderBlock(in_channels[3], in_channels[4])
        self.block5 = DecoderBlock(in_channels[4], prefinal_channels)
        self.final_conv = nn.Conv2d(prefinal_channels, final_channels, kernel_size=(1, 1))
        self.linear = nn.Linear(128, 10)
        self.pool = nn.AdaptiveAvgPool2d(1)
        # One FloatFunctional per skip connection. After prepare() each
        # instance gets its own output observer. Sharing a single instance
        # across all four adds would merge their ranges into one
        # scale/zero_point and hurt quantized accuracy.
        self.add = torch.nn.quantized.FloatFunctional()     # x3 + block1(x4)
        self.add_x2 = torch.nn.quantized.FloatFunctional()  # x2 + block2(...)
        self.add_x1 = torch.nn.quantized.FloatFunctional()  # x1 + block3(...)
        self.add_x0 = torch.nn.quantized.FloatFunctional()  # x0 + block4(...)

        # The base classifier head is not used by this model.
        del self.fc

    def forward(self, x):
        """Return ``(segmentation_output, class_logits)`` as float tensors."""
        x = self.quant(x)
        x, cls = self._forward_impl(x)
        x = self.dequant(x)
        cls = self.dequant(cls)
        return x, cls

    def _forward_impl(self, x):
        # Encoder.
        x0 = self.conv1(x)
        x0 = self.bn1(x0)
        x0 = self.relu(x0)
        x1 = self.maxpool(x0)
        x1 = self.layer1(x1)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        # Classification head on the deepest features.
        cls_out = self.linear(self.pool(x4).view(x4.size(0), -1))
        # Decoder with additive skip connections, each using its own adder.
        x4 = self.block1(x4)
        x3 = self.add.add(x3, x4)
        x3 = self.block2(x3)
        x2 = self.add_x2.add(x2, x3)
        x2 = self.block3(x2)
        x1 = self.add_x1.add(x1, x2)
        x1 = self.block4(x1)
        x0 = self.add_x0.add(x0, x1)
        x0 = self.block5(x0)
        x0 = self.final_conv(x0)

        return x0, cls_out


def print_size_of_model(model):
    """Print the serialized size of ``model.state_dict()`` in MB (1e6 bytes).

    The state_dict is saved to an in-memory buffer. No file is written to
    the working directory, so an existing ``temp.p`` is never overwritten
    or deleted, and nothing is left behind if saving fails.
    """
    import io

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    print('Size (MB):', buffer.getbuffer().nbytes / 1e6)


def get_model():
    """Build the float ``ResNetUnet`` with a [2, 2, 2, 2] block layout.

    In-place ReLUs are replaced with non-inplace ones via ``_replace_relu``
    before the model is returned.
    """
    net = ResNetUnet(QuantizableBasicBlock, [2, 2, 2, 2])
    _replace_relu(net)
    return net


def quantize_model(model, backend):
    """Statically quantize ``model`` in place for the given quantized engine.

    Sets the global quantized engine and attaches a backend-specific
    qconfig. Then fuses via ``model.fuse_model()``, calibrates observers with
    one forward pass, and converts to quantized modules.

    Args:
        model: module exposing ``fuse_model()`` whose forward is wrapped in
            QuantStub/DeQuantStub; modified in place.
        backend: ``'fbgemm'`` or ``'qnnpack'``.

    Raises:
        RuntimeError: if ``backend`` is not available in this PyTorch build.
        ValueError: if ``backend`` is available but no qconfig is defined here.
    """
    # fbgemm uses per-channel weight observers, qnnpack per-tensor ones.
    weight_observers = {
        'fbgemm': torch.quantization.default_per_channel_weight_observer,
        'qnnpack': torch.quantization.default_weight_observer,
    }
    # NOTE(review): calibrating on a single uniform-random image gives the
    # observers unrepresentative activation ranges; real sample images
    # would be preferable.
    _dummy_input_data = torch.rand(1, 3, 224, 224)
    if backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: {}".format(backend))
    # Previously an engine such as 'none' silently left the qconfig unset.
    if backend not in weight_observers:
        raise ValueError("No qconfig defined for quantized backend: {}".format(backend))
    torch.backends.quantized.engine = backend
    model.eval()
    model.qconfig = torch.quantization.QConfig(
        activation=torch.quantization.default_observer,
        weight=weight_observers[backend])

    model.fuse_model()
    torch.quantization.prepare(model, inplace=True)
    # Calibration only needs forward statistics; skip autograd bookkeeping.
    with torch.no_grad():
        model(_dummy_input_data)
    torch.quantization.convert(model, inplace=True)


def evaltime(model, n=500, input_shape=(1, 3, 224, 224), warmup=10):
    """Return the mean seconds per forward pass of ``model``.

    Args:
        model: callable taking one tensor (e.g. an ``nn.Module`` in eval mode).
        n: number of timed calls to average over.
        input_shape: shape of the all-ones input tensor. The same tensor is
            reused for every call.
        warmup: untimed calls run first, so one-off first-run costs do not
            skew the average.

    Returns:
        Average elapsed seconds per call over the ``n`` timed calls.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time.
    from time import perf_counter

    if n <= 0:
        raise ValueError("n must be a positive integer")
    # Allocate the input once, outside the timed region.
    x = torch.ones(*input_shape)
    with torch.no_grad():
        for _ in range(warmup):
            model(x)
        start = perf_counter()
        for _ in range(n):
            model(x)
        elapsed = perf_counter() - start
    return elapsed / n


if __name__ == "__main__":
    # Guarded so importing this module does not run the benchmark.
    # Single-threaded, so float and quantized timings are directly comparable.
    torch.set_num_threads(1)
    model = get_model()
    model.eval()
    print('time/image, initial model', evaltime(model))  # reported: 0.025
    print_size_of_model(model)  # reported: Size (MB): 2.89155
    # NOTE: QNNPACK timings taken on an x86 desktop may not reflect
    # performance on ARM/Android devices.
    quantize_model(model, 'qnnpack')
    print('time/image, quantized model', evaltime(model))  # reported: 0.101
    print_size_of_model(model)  # reported: Size (MB): 0.740964


With QNNPACK, the time you are measuring is on an x86 CPU. I am not sure how fast the QNNPACK quantized implementation is compared to float on x86.

Can you try with fbgemm on a desktop to see what you get?

Also, you should not reuse the same FloatFunctional module (self.add.add) for several additions in forward. Each instance needs to collect its own statistics, so sharing one will lead to poor accuracy in the quantized model.

Thank you!
I’ve tried FBGEMM, but it doesn’t seem to be supported on most Android devices (armeabi-v7a, arm64-v8a). Then I found the post “FBGEMM with PyTorch Mobile”, which points out that FBGEMM is only compatible with the x86 architecture. Do you think I should still try FBGEMM?

I’ll try removing the repeated self.add.add calls in forward.