INT8 quantized model is much slower than fp32 model on CPU

Hi, all

I finally succeeded in converting the fp32 model to an int8 model, thanks to the PyTorch forum community :slight_smile:.
To make sure the model really was quantized, I checked that the quantized model is smaller than the fp32 model on disk (500 MB -> 130 MB).
However, running the quantized model is much slower than running the fp32 model (700 ms -> 2.4 s).

I converted the pre-trained VGG16 model from torchvision.models.
I am working on an NVIDIA Jetson TX2, and I verified that the quantized MobileNet in torchvision.models.quantization.mobilenet is much faster there than the fp32 MobileNet.
So I suspect that something in my conversion is wrong.

If you guys need more information, please let me know.
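
For reference, the latency numbers above were measured roughly like this (a sketch; the 1x3x224x224 input and batch size of 1 are assumptions for illustration):

import time
import torch

model.eval()  # the fp32 or quantized model
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    for _ in range(3):  # warm-up runs
        model(dummy_input)
    start = time.time()
    model(dummy_input)
    print('latency: %.1f ms' % ((time.time() - start) * 1000))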

This is the output of print(quantized_model):

RecursiveScriptModule(
  original_name=VGG
  (features): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=ReLU)
    (2): RecursiveScriptModule(original_name=Conv2d)
    (3): RecursiveScriptModule(original_name=ReLU)
    (4): RecursiveScriptModule(original_name=MaxPool2d)
    (5): RecursiveScriptModule(original_name=Conv2d)
    (6): RecursiveScriptModule(original_name=ReLU)
    (7): RecursiveScriptModule(original_name=Conv2d)
    (8): RecursiveScriptModule(original_name=ReLU)
    (9): RecursiveScriptModule(original_name=MaxPool2d)
    (10): RecursiveScriptModule(original_name=Conv2d)
    (11): RecursiveScriptModule(original_name=ReLU)
    (12): RecursiveScriptModule(original_name=Conv2d)
    (13): RecursiveScriptModule(original_name=ReLU)
    (14): RecursiveScriptModule(original_name=Conv2d)
    (15): RecursiveScriptModule(original_name=ReLU)
    (16): RecursiveScriptModule(original_name=MaxPool2d)
    (17): RecursiveScriptModule(original_name=Conv2d)
    (18): RecursiveScriptModule(original_name=ReLU)
    (19): RecursiveScriptModule(original_name=Conv2d)
    (20): RecursiveScriptModule(original_name=ReLU)
    (21): RecursiveScriptModule(original_name=Conv2d)
    (22): RecursiveScriptModule(original_name=ReLU)
    (23): RecursiveScriptModule(original_name=MaxPool2d)
    (24): RecursiveScriptModule(original_name=Conv2d)
    (25): RecursiveScriptModule(original_name=ReLU)
    (26): RecursiveScriptModule(original_name=Conv2d)
    (27): RecursiveScriptModule(original_name=ReLU)
    (28): RecursiveScriptModule(original_name=Conv2d)
    (29): RecursiveScriptModule(original_name=ReLU)
    (30): RecursiveScriptModule(original_name=MaxPool2d)
  )
  (avgpool): RecursiveScriptModule(original_name=AdaptiveAvgPool2d)
  (classifier): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(
      original_name=Linear
      (_packed_params): RecursiveScriptModule(original_name=LinearPackedParams)
    )
    (1): RecursiveScriptModule(original_name=ReLU)
    (2): RecursiveScriptModule(original_name=Dropout)
    (3): RecursiveScriptModule(
      original_name=Linear
      (_packed_params): RecursiveScriptModule(original_name=LinearPackedParams)
    )
    (4): RecursiveScriptModule(original_name=ReLU)
    (5): RecursiveScriptModule(original_name=Dropout)
    (6): RecursiveScriptModule(
      original_name=Linear
      (_packed_params): RecursiveScriptModule(original_name=LinearPackedParams)
    )
  )
  (quant): RecursiveScriptModule(original_name=Quantize)
  (dequant): RecursiveScriptModule(original_name=DeQuantize)
)

The following code is the conversion code that I wrote.

from torch.quantization import QuantStub, DeQuantStub

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization

# # Setup warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.quantization'
)

# Specify random seed for repeatable results
torch.manual_seed(191009)


from torch.hub import load_state_dict_from_url


__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        x = self.dequant(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def vgg11(pretrained=False, progress=True, **kwargs):
    r"""VGG 11-layer model (configuration "A") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)


def vgg11_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 11-layer model (configuration "A") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)


def vgg13(pretrained=False, progress=True, **kwargs):
    r"""VGG 13-layer model (configuration "B")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)


def vgg13_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 13-layer model (configuration "B") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)


def vgg16(pretrained=False, progress=True, **kwargs):
    r"""VGG 16-layer model (configuration "D")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)


def vgg16_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 16-layer model (configuration "D") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)


def vgg19(pretrained=False, progress=True, **kwargs):
    r"""VGG 19-layer model (configuration "E")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)


def vgg19_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 19-layer model (configuration 'E') with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader, neval_batches):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end = '')
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                return top1, top5

    return top1, top5

def load_model(model_file):
    if model_file is None:
        model = vgg16(pretrained=True)
    else:
        model = vgg16()
        state_dict = torch.load(model_file)
        model.load_state_dict(state_dict)
    model.to('cpu')
    return model

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

'''
import requests

url = 'https://s3.amazonaws.com/pytorch-tutorial-assets/imagenet_1k.zip'
filename = '~/Downloads/imagenet_1k_data.zip'

r = requests.get(url)

with open(filename, 'wb') as f:
    f.write(r.content)
'''
import torchvision
import torchvision.transforms as transforms
'''
imagenet_dataset = torchvision.datasets.ImageNet(
    'data/imagenet_1k',
    split='train',
    download=True,
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]))
'''

def prepare_data_loaders(data_path):

    traindir = os.path.join(data_path, 'train')
    valdir = os.path.join(data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    dataset = torchvision.datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

data_path = 'data/imagenet_1k'
saved_model_dir = 'data/'
scripted_float_model_file = 'vgg16_quantization_scripted.pth'
scripted_quantized_model_file = 'vgg16_quantization_scripted_quantized.pth'

train_batch_size = 30
eval_batch_size = 30

data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
float_model = load_model(None).to('cpu')

float_model.eval()

num_eval_batches = 10

print("Size of baseline model")
print_size_of_model(float_model)

top1, top5 = evaluate(float_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)


num_calibration_batches = 10

per_channel_quantized_model = load_model(None).to('cpu')
per_channel_quantized_model.eval()

torch.backends.quantized.engine = 'qnnpack'
per_channel_quantized_model.qconfig = torch.quantization.get_default_qconfig('qnnpack')

print(per_channel_quantized_model.qconfig)

torch.quantization.prepare(per_channel_quantized_model, inplace=True)
evaluate(per_channel_quantized_model, criterion, data_loader, num_calibration_batches)
torch.quantization.convert(per_channel_quantized_model, inplace=True)
top1, top5 = evaluate(per_channel_quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(per_channel_quantized_model), saved_model_dir + scripted_quantized_model_file)

I printed your quantized model def before scripting: https://gist.github.com/vkuzo/edb2121a757d5789977935ad56820a24

One improvement would be to fuse adjacent Conv-ReLU modules together, so that they can use the faster fused quantized kernels. A toy example:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(4, 4, 1),
    nn.ReLU(),
)

# Fuse each Conv and ReLU (implement this for your model)
torch.quantization.fuse_modules(model, [['0', '1']], inplace=True)
print(model)

# prepare
torch.backends.quantized.engine = 'qnnpack'
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.quantization.prepare(model, inplace=True)

# calibrate (toy example)
input_data = torch.randn(4, 4, 4, 4)
model(input_data)

# convert
torch.quantization.convert(model, inplace=True)

# should see QuantizedConvReLU2d module
print(model)
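
If it helps, here is a rough sketch of how those Conv/ReLU (and Linear/ReLU) pairs could be collected for the full VGG defined above (untested; it assumes the eager-mode model from your script, before prepare/convert):

import torch
import torch.nn as nn

def get_fuse_list(vgg_model):
    # Collect adjacent (Conv2d or Linear, ReLU) pairs inside 'features' and 'classifier'.
    pairs = []
    for seq_name in ('features', 'classifier'):
        children = list(getattr(vgg_model, seq_name).named_children())
        for (name_a, mod_a), (name_b, mod_b) in zip(children, children[1:]):
            if isinstance(mod_a, (nn.Conv2d, nn.Linear)) and isinstance(mod_b, nn.ReLU):
                pairs.append(['%s.%s' % (seq_name, name_a), '%s.%s' % (seq_name, name_b)])
    return pairs

vgg = load_model(None).eval()  # eager fp32 VGG from the script above
fuse_list = get_fuse_list(vgg)
torch.quantization.fuse_modules(vgg, fuse_list, inplace=True)  # dotted names are accepted
print(vgg)  # Conv/ReLU pairs should now show up as ConvReLU2d, Linear/ReLU as LinearReLU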

If you still see a performance gap after this, it might be good to check whether QNNPACK is actually enabled on your target device.
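
One quick way to check is something like this (sketch):

import torch
print(torch.backends.quantized.supported_engines)  # should contain 'qnnpack' on ARM builds
print(torch.backends.quantized.engine)             # currently selected backend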


And you can also fuse Linear + ReLU.

Thank you for your reply.

Even after fusing Conv+ReLU and Linear+ReLU, there is no speed improvement.
QNNPACK is definitely enabled, because the quantized MobileNet in torchvision.models.quantization.mobilenet (which uses the qnnpack backend) is faster than the fp32 model on my device.

Could you suggest another feasible solution?

Maybe you should check your thread count, and try torch.set_num_threads(1).
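
Something like this (sketch):

import torch
print(torch.get_num_threads())  # how many intra-op threads are currently used
torch.set_num_threads(1)        # try single-threaded inference for comparison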

It could also be related to op support in QNNPACK. PyTorch has a fork of QNNPACK which lives here (https://github.com/pytorch/pytorch/tree/172f31171a3395cc299044e06a9665fec676ddd6/aten/src/ATen/native/quantized/cpu/qnnpack), and its README lists the supported ops.

Your model has a few modules which are not supported: AdaptiveAvgPool2d and Dropout. They will still run, but without fast ARM kernels. Just for debugging's sake, you could check whether removing these modules, or replacing them with alternatives that are optimized for ARM, fixes the speed issue.
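
For example, something along these lines, for debugging only (a sketch; it assumes the eager VGG instance is named model and the inputs are 224x224, where the feature map entering avgpool is already 7x7):

import torch.nn as nn

# Debugging-only sketch: swap the modules without fast ARM kernels for no-ops.
# With 224x224 inputs, VGG's feature map entering avgpool is already 7x7, so
# AdaptiveAvgPool2d((7, 7)) can be replaced by Identity; Dropout is a no-op in eval mode anyway.
model.avgpool = nn.Identity()
for i, m in enumerate(list(model.classifier)):
    if isinstance(m, nn.Dropout):
        model.classifier[i] = nn.Identity()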

Can you print your model right before scripting it and verify that you get this (https://gist.github.com/vkuzo/edb2121a757d5789977935ad56820a24)?

This is the output of the model printed right before scripting.

VGG(
  (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.015086950734257698, zero_point=2, padding=(1, 1))
  (relu1): Identity()
  (conv2): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.005462500732392073, zero_point=0, padding=(1, 1))
  (relu2): Identity()
  (conv3): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.002446091501042247, zero_point=0, padding=(1, 1))
  (relu3): Identity()
  (conv4): QuantizedConvReLU2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.0008910637116059661, zero_point=1, padding=(1, 1))
  (relu4): Identity()
  (conv5): QuantizedConvReLU2d(128, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.0006946324720047414, zero_point=1, padding=(1, 1))
  (relu5): Identity()
  (conv6): QuantizedConvReLU2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.0002671453694347292, zero_point=1, padding=(1, 1))
  (relu6): Identity()
  (conv7): QuantizedConvReLU2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.00013638826203532517, zero_point=3, padding=(1, 1))
  (relu7): Identity()
  (conv8): QuantizedConvReLU2d(256, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.00012979305756743997, zero_point=0, padding=(1, 1))
  (relu8): Identity()
  (conv9): QuantizedConvReLU2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.00012682013039011508, zero_point=1, padding=(1, 1))
  (relu9): Identity()
  (conv10): QuantizedConvReLU2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=8.234349661506712e-05, zero_point=1, padding=(1, 1))
  (relu10): Identity()
  (conv11): QuantizedConvReLU2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=9.820431296247989e-05, zero_point=0, padding=(1, 1))
  (relu11): Identity()
  (conv12): QuantizedConvReLU2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=8.165000326698646e-05, zero_point=0, padding=(1, 1))
  (relu12): Identity()
  (conv13): QuantizedConvReLU2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=8.769309351919219e-05, zero_point=0, padding=(1, 1))
  (relu13): Identity()
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout(p=0.5, inplace=False)
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (fc1): QuantizedLinearReLU(
    in_features=25088, out_features=4096, scale=6.691644375678152e-05, zero_point=0
    (_packed_params): LinearPackedParams()
  )
  (relu14): Identity()
  (fc2): QuantizedLinearReLU(
    in_features=4096, out_features=4096, scale=8.03592411102727e-05, zero_point=0
    (_packed_params): LinearPackedParams()
  )
  (relu15): Identity()
  (fc3): QuantizedLinear(
    in_features=4096, out_features=1000, scale=0.0001865544618340209, zero_point=131
    (_packed_params): LinearPackedParams()
  )
  (softmax): Softmax(dim=1)
  (quant): Quantize(scale=tensor([0.0186]), zero_point=tensor([114]), dtype=torch.quint8)
  (dequant): DeQuantize()
)

In my model there are a lot of Identity() layers; I believe they are generated by the fuse_modules() function. I don't think they affect performance, but do they affect the execution latency?

Thanks. The Identity layers do not do anything and shouldn't contribute to the runtime. Your model definition after quantization looks right. Unfortunately we don't have a Jetson TX2, so we can't check locally, and your setup looks right. At this point it might be good to try to bisect the issue: check whether any particular layers are slow (in particular, ones not supported by QNNPACK).

FYI, to investigate the bottleneck in model execution, I profiled my quantized model with torch.autograd.profiler.profile().
It looks like there is some problem with quantized::conv2d, even though QNNPACK supports it.

I will share the further progress, thank you so much @Vasiliy_Kuznetsov ! :slight_smile:
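
For reference, the profiling call was roughly like this (a sketch; the converted model name and the 1x3x224x224 input are assumptions for illustration), and the table below is the resulting output:

import torch

dummy_input = torch.randn(1, 3, 224, 224)
with torch.autograd.profiler.profile() as prof:
    quantized_model(dummy_input)  # the converted int8 model
print(prof.key_averages().table(sort_by='self_cpu_time_total'))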

 ---------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                         Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
---------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
quantized::conv2d            92.64%           2.221s           92.67%           2.222s           170.929ms        13
quantized::linear            5.69%            136.383ms        5.69%            136.452ms        45.484ms         3
_adaptive_avg_pool2d         0.64%            15.370ms         0.64%            15.370ms         15.370ms         1
relu_                        0.49%            11.650ms         0.49%            11.650ms         776.697us        15
quantized_max_pool2d         0.40%            9.491ms          0.40%            9.491ms          1.898ms          5
quantize_per_tensor          0.09%            2.261ms          0.09%            2.261ms          2.261ms          1
contiguous                   0.02%            410.239us        0.02%            450.143us        28.134us         16
_empty_affine_quantized      0.01%            304.640us        0.01%            304.640us        17.920us         17
max                          0.01%            160.672us        0.01%            160.672us        160.672us        1
q_scale                      0.00%            113.983us        0.00%            113.983us        2.478us          46
clone                        0.00%            102.719us        0.00%            102.719us        102.719us        1
dequantize                   0.00%            50.016us         0.00%            50.016us         50.016us         1
q_zero_point                 0.00%            45.728us         0.00%            45.728us         1.524us          30
view                         0.00%            44.704us         0.00%            44.704us         44.704us         1
max_pool2d                   0.00%            36.128us         0.40%            9.527ms          1.905ms          5
select                       0.00%            31.680us         0.00%            31.680us         31.680us         1
reshape                      0.00%            30.208us         0.01%            195.071us        97.535us         2
_unsafe_view                 0.00%            17.440us         0.00%            17.440us         17.440us         1
empty_like                   0.00%            13.888us         0.00%            39.904us         39.904us         1
_local_scalar_dense          0.00%            13.504us         0.00%            13.504us         4.501us          3
is_floating_point            0.00%            13.440us         0.00%            13.440us         13.440us         1
item                         0.00%            12.448us         0.00%            25.952us         8.651us          3
flatten                      0.00%            7.456us          0.01%            138.463us        138.463us        1
adaptive_avg_pool2d          0.00%            5.312us          0.64%            15.376ms         15.376ms         1
dropout                      0.00%            5.216us          0.00%            5.216us          2.608us          2
qscheme                      0.00%            4.736us          0.00%            4.736us          4.736us          1
is_complex                   0.00%            3.360us          0.00%            3.360us          3.360us          1
sizes                        0.00%            2.656us          0.00%            2.656us          2.656us          1
size                         0.00%            2.240us          0.00%            2.240us          2.240us          1
---------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------

@Jungmo_Ahn I ran into the same problem. Have you solved it?

I am still trying to solve it. If there is a meaningful result, I will share it here.

@Jungmo_Ahn Have you solved this? I ran into the same problem. Thanks.

If you are on an x86 machine (e.g. Intel), I suggest using the 'fbgemm' backend; if you are on ARM, that is where 'qnnpack' helps.
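
For example (a sketch; the model and the calibration loop are placeholders):

import torch

# Pick the quantization backend that matches the target CPU before prepare/convert.
backend = 'fbgemm' if 'fbgemm' in torch.backends.quantized.supported_engines else 'qnnpack'
torch.backends.quantized.engine = backend
model.qconfig = torch.quantization.get_default_qconfig(backend)

torch.quantization.prepare(model, inplace=True)
# ... run calibration batches through the model here ...
torch.quantization.convert(model, inplace=True)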