Slow quantization

I’ve tried to quantize a simple model built from conv + bn + relu blocks, but it runs much slower in int8 than in float32.
Am I missing something here?

Code To Reproduce

import os
import time

import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

backend = 'qnnpack'
# backend = 'fbgemm'
import torch
torch.backends.quantized.engine = backend


class DownBlockQ(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.quant_input = QuantStub()
        self.dequant_output = DeQuantStub()

        self.conv1 = nn.Conv2d(in_ch, in_ch, 4, stride=2, padding=1, groups=in_ch)
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(in_ch, out_ch, 1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        # x = self.quant_input(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        # x = self.dequant_output(x)
        return x

    def fuse_model(self):
        torch.quantization.fuse_modules(self, ['conv1', 'bn1', 'relu1'], inplace=True)
        torch.quantization.fuse_modules(self, ['conv2', 'bn2', 'relu2'], inplace=True)


class Model(nn.Module):
    def __init__(self, filters=22):
        super().__init__()
        self.quant_input = QuantStub()
        self.dequant_output = DeQuantStub()

        self.db1 = DownBlockQ(filters * 1, filters * 2)  # 128
        self.db2 = DownBlockQ(filters * 2, filters * 4)  # 64
        self.db3 = DownBlockQ(filters * 4, filters * 8)  # 32

    def forward(self, x):
        x = self.quant_input(x)
        x = self.db1(x)
        x = self.db2(x)
        x = self.db3(x)
        x = self.dequant_output(x)
        return x


def fuse_model(model):
    if hasattr(model, 'fuse_model'):
        model.fuse_model()

    for p in list(model.modules())[1:]:
        fuse_model(p)


def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p") / 1e6)
    os.remove('temp.p')


def benchmark(func, iters=10, *args):
    t1 = time.time()
    for _ in range(iters):
        res = func(*args)
    print(f'{((time.time() - t1) / iters):.6f} sec')
    return res


def quantize():
    dummy = torch.rand(1, 22, 256, 256)
    # model = DownBlockQ(22 * 1, 22 * 2)
    model = Model(filters=22)
    model = model.eval()
    print("Before quantization")
    print_size_of_model(model)

    benchmark(model, 20, dummy)
    # print(model)
    fuse_model(model)

    model.qconfig = torch.quantization.get_default_qconfig(backend)
    # print(model.qconfig)
    torch.quantization.prepare(model, inplace=True)
    torch.quantization.convert(model, inplace=True)

    # print(model)
    print("After quantization")
    print_size_of_model(model)
    benchmark(model, 20, dummy)
    # torch.jit.script(model).save('models/model_scripted.pt')


if __name__ == '__main__':
    quantize()

Expected behavior

The int8 model should be 2-3 times faster than the float32 model.

Environment

PyTorch version: 1.7.0.dev20200727
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Ubuntu 20.04 LTS
GCC version: (Ubuntu 8.4.0-3ubuntu2) 8.4.0
CMake version: version 3.16.3

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: GeForce GTX 1070
Nvidia driver version: 440.100
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.19.0
[pip3] torch==1.7.0.dev20200727
[pip3] torchvision==0.8.0.dev20200727
[conda] Could not collect

Thanks for flagging. The input sizes to the conv layers seem a bit unconventional, so I’m wondering if that is causing the slowdown. Are these sizes part of an actual model?
cc @dskhudia

I tried printing the model:

Model(
  (quant_input): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant_output): DeQuantize()
  (db1): DownBlockQ(
    (quant_input): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
    (dequant_output): DeQuantize()
    (conv1): QuantizedConvReLU2d(22, 22, kernel_size=(4, 4), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1), groups=22)
    (bn1): Identity()
    (relu1): Identity()
    (conv2): QuantizedConvReLU2d(22, 44, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
    (bn2): Identity()
    (relu2): Identity()
  )
  (db2): DownBlockQ(
    (quant_input): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
    (dequant_output): DeQuantize()
    (conv1): QuantizedConvReLU2d(44, 44, kernel_size=(4, 4), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1), groups=44)
    (bn1): Identity()
    (relu1): Identity()
    (conv2): QuantizedConvReLU2d(44, 88, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
    (bn2): Identity()
    (relu2): Identity()
  )
  (db3): DownBlockQ(
    (quant_input): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
    (dequant_output): DeQuantize()
    (conv1): QuantizedConvReLU2d(88, 88, kernel_size=(4, 4), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1), groups=88)
    (bn1): Identity()
    (relu1): Identity()
    (conv2): QuantizedConvReLU2d(88, 176, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
    (bn2): Identity()
    (relu2): Identity()
  )
)

I think kernel size = 4 and the depthwise conv are the culprit here. Quantized depthwise conv is optimized mainly for the common kernel sizes 3 and 5. Just to reiterate @supriyar’s question: is there any reason to use kernel size 4?
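
For reference, here is a minimal, self-contained sketch (not from the original report; the helper name time_depthwise, the channel count of 88, and the 64x64 input are illustrative choices) that times a single quantized depthwise conv for kernel sizes 3, 4 and 5, which isolates how much the kernel size alone contributes:

import time

import torch
import torch.nn as nn

torch.backends.quantized.engine = 'fbgemm'


def time_depthwise(kernel_size, channels=88, iters=50):
    # Quantize a single depthwise conv via the eager-mode workflow.
    m = nn.Sequential(
        nn.Conv2d(channels, channels, kernel_size, stride=2, padding=1, groups=channels)
    ).eval()
    m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(m, inplace=True)
    m(torch.rand(1, channels, 64, 64))  # one calibration pass
    torch.quantization.convert(m, inplace=True)

    # The converted module expects a quantized tensor (no QuantStub here).
    x = torch.quantize_per_tensor(torch.rand(1, channels, 64, 64), 1.0, 0, torch.quint8)
    t1 = time.time()
    for _ in range(iters):
        m(x)
    return (time.time() - t1) / iters


for k in (3, 4, 5):
    print(f'kernel_size={k}: {time_depthwise(k):.6f} sec')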

@dskhudia @supriyar Thanks for your replies! Yes, it is part of an actual model, so changing it is very undesirable. I’ve tried convs with kernel sizes 3 and 5, but even with that configuration int8 is slower than float32.

For kernel sizes 3 and 5, I can take a look to see why it’s slow.


Please take a look:

import os
import time

import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

# backend = 'qnnpack'
backend = 'fbgemm'
import torch

torch.backends.quantized.engine = backend


class DownBlockQ(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1, groups=in_ch)
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(in_ch, out_ch, 1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        return x

    def fuse_model(self):
        torch.quantization.fuse_modules(self, ['conv1', 'bn1', 'relu1'], inplace=True)
        torch.quantization.fuse_modules(self, ['conv2', 'bn2', 'relu2'], inplace=True)


class Model(nn.Module):
    def __init__(self, filters=22, quant=True):
        super().__init__()
        self.quant = quant
        self.quant_input = QuantStub()
        self.dequant_output = DeQuantStub()

        self.db1 = DownBlockQ(filters * 1, filters * 2)  # 128
        self.db2 = DownBlockQ(filters * 2, filters * 4)  # 64
        self.db3 = DownBlockQ(filters * 4, filters * 8)  # 32

    def forward(self, x):
        if self.quant:
            x = self.quant_input(x)
        x = self.db1(x)
        x = self.db2(x)
        x = self.db3(x)
        if self.quant:
            x = self.dequant_output(x)
        return x


def fuse_model(model):
    if hasattr(model, 'fuse_model'):
        model.fuse_model()

    for p in list(model.modules())[1:]:
        fuse_model(p)


def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p") / 1e6)
    os.remove('temp.p')


def benchmark(func, iters=10, *args):
    t1 = time.time()
    for _ in range(iters):
        res = func(*args)
    print(f'{((time.time() - t1) / iters):.6f} sec')
    return res


def quantize():
    dummy = torch.rand(1, 22, 256, 256)
    model = Model(filters=22, quant=False).eval()
    print("Before quantization")
    print_size_of_model(model)
    benchmark(model, 20, dummy)

    # print(model)
    model = Model(filters=22, quant=True).eval()
    fuse_model(model)

    model.qconfig = torch.quantization.get_default_qconfig(backend)
    # print(model.qconfig)
    torch.quantization.prepare(model, inplace=True)
    torch.quantization.convert(model, inplace=True)

    print("After quantization")
    print_size_of_model(model)
    benchmark(model, 20, dummy)


if __name__ == '__main__':
    quantize()

Hi @dklvch,

Int8 depthwise convolution is very slow when the channel count (filters here) is not a multiple of 8. Could you try filters = 16 or 24?

dummy = torch.rand(1, 22, 256, 256) => dummy = torch.rand(1, 24, 256, 256)
model = Model(filters=22, quant=False).eval() => model = Model(filters=24, quant=False).eval()
model = Model(filters=22, quant=True).eval() => model = Model(filters=24, quant=True).eval()
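
To sanity-check the multiple-of-8 point in isolation, a similar illustrative sketch (again, the helper name and the sizes below are made up, not from this thread) times the same kind of quantized depthwise conv for channel counts 16, 22 and 24:

import time

import torch
import torch.nn as nn

torch.backends.quantized.engine = 'fbgemm'


def time_depthwise_channels(channels, iters=20):
    # Eager-mode quantization of a single depthwise conv, kernel size fixed at 3.
    m = nn.Sequential(
        nn.Conv2d(channels, channels, 3, stride=2, padding=1, groups=channels)
    ).eval()
    m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(m, inplace=True)
    m(torch.rand(1, channels, 256, 256))  # one calibration pass
    torch.quantization.convert(m, inplace=True)

    # Feed a quantized tensor directly, since there is no QuantStub in this module.
    x = torch.quantize_per_tensor(torch.rand(1, channels, 256, 256), 1.0, 0, torch.quint8)
    t1 = time.time()
    for _ in range(iters):
        m(x)
    return (time.time() - t1) / iters


for c in (16, 22, 24):
    print(f'channels={c}: {time_depthwise_channels(c):.6f} sec')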