No performance improvement from a quantized model in PyTorch

I have trained a model in PyTorch using the float data type. I want to reduce inference time by converting this model to a quantized model. I have used the torch.quantization.convert and torch.quantization.quantize_dynamic APIs to convert my model's weights to 8-bit integers. However, when I use this model for inference, I do not get any performance improvement. Am I doing something wrong here?

The UNet model code:

def gen_initialization(m):
    """Initialize the parameters of a single module in place.

    Meant to be passed to ``nn.Module.apply``. Handled module types:

    * ``nn.Conv2d``: weight drawn from N(0, 2 / fan_out), where
      fan_out = out_channels * kernel_h * kernel_w (He/Kaiming normal,
      fan-out mode); bias, if present, set to 0.
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0, if the layer
      is affine.

    Any other module is left untouched.
    """
    if isinstance(m, nn.Conv2d):
        out_ch, _, k_h, k_w = m.weight.shape
        nn.init.normal_(m.weight, std=math.sqrt(2.0 / (out_ch * k_h * k_w)))
        # Conv2d(..., bias=False) has ``bias is None``; constant_ would crash on it.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(..., affine=False) has no weight/bias parameters.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

class TripleConv(nn.Module):
    """Three stacked 3x3 Conv2d -> BatchNorm2d -> LeakyReLU(0.1) units.

    Channel widths go in_ch -> mid -> mid -> out_ch, where
    mid = (in_ch + out_ch) // 2. Every conv uses stride 1 and padding 1,
    so the spatial size is preserved.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        hidden = (in_ch + out_ch) // 2
        stages = ((in_ch, hidden), (hidden, hidden), (hidden, out_ch))
        layers = []
        for c_in, c_out in stages:
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(num_features=c_out),
                nn.LeakyReLU(negative_slope=0.1),
            ])
        # Same flat Sequential (indices 0..8) as before, so state_dict keys match.
        self.conv = nn.Sequential(*layers)
        self.conv.apply(gen_initialization)

    def forward(self, x):
        return self.conv(x)


class Down(nn.Module):
    """Encoder stage: residual TripleConv followed by 2x2 average pooling.

    The pre-pooling feature map is kept on ``self.cache`` so that the
    matching decoder stage can read it as a skip connection
    (``UNet.forward`` reads ``downN.cache``).
    """

    def __init__(self, in_ch, out_ch):
        super(Down, self).__init__()
        # The residual adds the in_ch-channel input into an out_ch-channel
        # map, so the output can never be narrower than the input.
        if out_ch < in_ch:
            raise ValueError(
                f"Down requires out_ch >= in_ch, got in_ch={in_ch}, out_ch={out_ch}")
        self.triple_conv = TripleConv(in_ch, out_ch)
        self.avg_pool_conv = nn.AvgPool2d(2, 2)
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x):
        self.cache = self.triple_conv(x)
        # Residual connection. Adding x zero-padded along the channel axis is
        # the same as adding x into the first in_ch channels only. Doing it
        # that way avoids allocating a full-size zero tensor plus a
        # concatenated copy, and it cannot introduce a dtype mismatch.
        # NOTE(review): eager-mode static quantization cannot run this tensor
        # arithmetic on quantized tensors; it would need QuantStub/DeQuantStub
        # around it or nn.quantized.FloatFunctional.
        self.cache[:, :self.in_ch].add_(x)
        return self.avg_pool_conv(self.cache)


class Center(nn.Module):
    """Bottleneck: a single 3x3 Conv2d -> BatchNorm2d -> LeakyReLU(0.1) unit.

    Stride 1 and padding 1 keep the spatial size unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(num_features=out_ch),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
        self.conv.apply(gen_initialization)

    def forward(self, x):
        return self.conv(x)


class Up(nn.Module):
    """Decoder stage: 2x bilinear upsampling, channel concat with a skip map, TripleConv.

    ``in_ch`` must equal the upsampled input's channels plus the skip
    map's channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.triple_conv = TripleConv(in_ch, out_ch)

    def forward(self, x, cache):
        merged = torch.cat([self.upsample(x), cache], dim=1)
        return self.triple_conv(merged)


class UNet(nn.Module):
    """Four-level U-Net with residual encoder stages and concat skip connections.

    Args:
        in_ch: number of channels of both the input and the output image.
        first_ch: width of the first encoder stage. Each deeper stage doubles
            it. Any falsy value (None, 0) selects 32.

    Because of the four 2x poolings followed by 2x upsampling and
    concatenation, input height and width must be divisible by 16.
    """

    def __init__(self, in_ch, first_ch=None):
        super().__init__()
        base = first_ch or 32

        # Encoder: widths base, 2*base, 4*base, 8*base.
        self.down1 = Down(in_ch, base)
        self.down2 = Down(base, 2 * base)
        self.down3 = Down(2 * base, 4 * base)
        self.down4 = Down(4 * base, 8 * base)
        self.center = Center(8 * base, 8 * base)
        # Decoder: each Up receives the upsampled features concatenated with
        # the same-level skip map, hence the doubled input width.
        self.up4 = Up(16 * base, 4 * base)
        self.up3 = Up(8 * base, 2 * base)
        self.up2 = Up(4 * base, base)
        self.up1 = Up(2 * base, base)
        self.output = nn.Conv2d(base, in_ch, kernel_size=3, stride=1,
                                padding=1, bias=True)
        self.output.apply(gen_initialization)

    def forward(self, x):
        encoders = (self.down1, self.down2, self.down3, self.down4)
        for stage in encoders:
            x = stage(x)
        x = self.center(x)
        # Pair up4..up1 with down4..down1; each Down stored its skip map on .cache.
        decoders = (self.up4, self.up3, self.up2, self.up1)
        for stage, skip in zip(decoders, reversed(encoders)):
            x = stage(x, skip.cache)
        return self.output(x)

The inference code:

from tqdm import tqdm
import os
import numpy as np
import torch
import gan_network
import torch.nn.parallel
from torch.utils.data import DataLoader
import torch.utils.data as data
import random
import glob
import scipy.io
import time
# Order CUDA devices by PCI bus ID and expose only device 0 to this process
# (see issue #152). This must run before CUDA is first initialized.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")


class DataFolder(data.Dataset):
    """Dataset over ``.npy`` image files, yielding ``(tensor, path)`` pairs.

    Args:
        file: an iterable of paths, e.g. a list from ``glob.glob`` or an
            open text file with one path per line. Surrounding whitespace is
            stripped from each entry (including the trailing newline of file
            lines), and blank entries are skipped.

    The entries are shuffled once at construction.
    """

    def __init__(self, file):
        super(DataFolder, self).__init__()
        self.image_names = [name for name in (line.strip() for line in file) if name]
        random.shuffle(self.image_names)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, index):
        """Return the image at *index* with axis 2 moved to the front, plus its path.

        Arrays are presumably (H, W, C) on disk, which makes the tensor
        (C, H, W). The tensor keeps the dtype stored in the file.
        """
        path = self.image_names[index]
        img = np.moveaxis(np.load(path), 2, 0)
        return torch.from_numpy(img), path


def save_outputs(output, paths, dest_dir):
    """Write each network output of a batch to ``<dest_dir>/<stem>.mat``.

    Args:
        output: NumPy array of shape (N, C, H, W), one batch of network outputs.
        paths: the source ``.npy`` paths of the batch. Each output is saved
            under its source file's stem.
        dest_dir: directory the ``.mat`` files are written to.

    Each array is stored under the MATLAB variable ``'output'`` in
    (H, W, C) order.
    """
    channels_last = np.moveaxis(output, 1, -1)  # (N, C, H, W) -> (N, H, W, C)
    # zip() follows the real batch length, which is shorter than batch_size
    # for the last batch when drop_last=False.
    for arr, path in zip(channels_last, paths):
        # splitext only touches the extension (str.replace would also hit
        # a '.npy' inside the directory name).
        stem = os.path.splitext(os.path.basename(path))[0]
        scipy.io.savemat(os.path.join(dest_dir, stem + '.mat'), {'output': arr})


if __name__ == '__main__':
    batch_size = 1
    channels = 6
    model_path = 'D:/WorkProjects/Network_Training_Aqusens/FullFovReconst/network/network_epoch9.pth'
    test_data = glob.glob('D:/save/temp/*.npy')
    dest_dir = 'D:/save/temp/results/'
    # Selects which model the timed loop runs; True benchmarks the
    # dynamically quantized copy instead of the float network.
    use_quantized = False

    os.makedirs(dest_dir, exist_ok=True)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    net = gan_network.UNet(channels, 32)
    # Load before any DataParallel wrapping (the wrapper prefixes every
    # state_dict key with 'module.'). Mapping to CPU lets a GPU-saved
    # checkpoint load on any machine; .to(device) moves it afterwards.
    net.load_state_dict(torch.load(model_path, map_location='cpu'))
    net.to(device)
    # Required for inference: in train mode BatchNorm normalizes with the
    # current batch's statistics instead of the learned running statistics.
    net.eval()
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    if use_quantized:
        # NOTE: quantize_dynamic only swaps module types that have a dynamic
        # quantized implementation (nn.Linear, nn.LSTM and other RNN cells).
        # nn.Conv2d / nn.BatchNorm2d are not among them, so for this
        # all-convolutional UNet the returned copy is still a float model and
        # no speed-up is expected. Quantized kernels also only run on CPU.
        # Quantizing convolutions requires eager-mode static quantization
        # (fuse modules, QuantStub/DeQuantStub, prepare, calibrate, convert)
        # on CPU.
        model = torch.quantization.quantize_dynamic(
            net, {torch.nn.Conv2d, torch.nn.BatchNorm2d}, inplace=False)
    else:
        model = net

    dataset = DataFolder(file=test_data)
    print(f'{len(dataset)}')
    data_loader = DataLoader(dataset=dataset, num_workers=4,
                             batch_size=batch_size, shuffle=False,
                             drop_last=False, pin_memory=True)

    inference_time = 0.0
    t0 = time.perf_counter()
    with torch.no_grad():
        for images, paths in tqdm(data_loader):
            # float32 matches the dtype the original preallocated
            # torch.Tensor buffer imposed on every batch.
            images = images.to(device, dtype=torch.float32, non_blocking=True)
            t_start = time.perf_counter()
            output = model(images)
            if device.type == 'cuda':
                # CUDA kernels run asynchronously; wait so the measurement
                # covers the forward pass itself.
                torch.cuda.synchronize(device)
            inference_time += time.perf_counter() - t_start
            save_outputs(output.cpu().numpy(), paths, dest_dir)
    t1 = time.perf_counter()
    # Elapsed time includes data loading and .mat writing; inference time
    # covers only the forward passes.
    print(f'Elapsed time = {t1-t0}')
    print(f'Inference time = {inference_time}')

With both models, `net` and `quantized_model`, I get an elapsed time of around 30 seconds for 12 images.

I can no longer find the relevant documentation link, but PyTorch dynamic quantization is currently (v1.5) provided only for Linear and LSTM layers. A model that does not consist largely of those layers will not benefit from dynamic quantization.

(To confirm, print() the quantized model to see which layers have been replaced.)

Hi @dnaik

In the code you posted, I don't see where you quantized the model with the convert() API. If you haven't, quantize the model with the prepare() and convert() APIs.
Refer to the PyTorch Static Quantization documentation for how to quantize a model with prepare() and convert().

quantize_dynamic() is not applicable to your model because, as of now, it only supports quantizing LSTM and Linear layers.