Got slow speed on quantized model with fbgemm on X86

PyTorch/Libtorch 1.4.0

ABOUT CNN:

  1. Make a model just like MobileNetV3
  2. Do post-training static quantization with fbgemm
  3. The model size is reduced to a quarter of the original, but the inference speed drops to half of the original, and the CPU usage is about 2400%, which means the default OMP_NUM_THREADS is 24
  4. After “export OMP_NUM_THREADS=1”, the inference speed increases to 3 times the original
  5. After “export OMP_NUM_THREADS=6”, the inference speed is close to the original

After more testing, I found that the problem is in depth-wise conv where groups is not 1.
My question is “Is this normal?”

ABOUT RNN:

  1. Make a model with 2 LSTMs
  2. Do post-training dynamic quantization
  3. The model size is reduced to a quarter of the original, but the inference speed does not change significantly

My question is “Is this normal?”

cc @dskhudia @Zafar

Hi @wizardk,

Is the original running with a single thread?

Hi @dskhudia,
I tested it with 1 and 10 threads. To keep it simple, let’s test with 1 thread and limit OMP to 1 thread. Here are the details of the experiment.

1.Install Pytorch 1.4, download Libtorch 1.4

2.Prepare JIT model in Pytorch

import torch
from torch import nn
import torch.quantization as Q

class TestConv(nn.Module):
    """Three stacked Conv2d -> BatchNorm2d -> activation stages.

    Stage layout inside ``self.lyr`` (indices matter, they are used as
    submodule names when fusing):
      0-2: 1x1 conv, i_c -> i_c
      3-5: 3x3 conv, i_c -> o_c (groups=i_c when ``dw`` is true, else 1)
      6-8: 1x1 conv, o_c -> o_c

    Args:
        q: when true, use non in-place ``nn.ReLU``; otherwise in-place ``nn.ReLU6``.
        dw: when true, the 3x3 convolution is grouped with ``groups=i_c``.
        i_c: number of input channels.
        o_c: number of output channels.
    """

    def __init__(self, q, dw, i_c, o_c):
        super(TestConv, self).__init__()

        def stage(c_in, c_out, k, groups):
            # Layers are created in conv, bn, activation order so that the
            # parameter initialisation order is the same for every stage.
            conv = nn.Conv2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=k,
                stride=1,
                padding=k // 2,
                dilation=1,
                groups=groups,
                bias=False,
            )
            norm = nn.BatchNorm2d(num_features=c_out)
            act = nn.ReLU(inplace=False) if q else nn.ReLU6(inplace=True)
            return [conv, norm, act]

        layers = []
        layers += stage(i_c, i_c, 1, 1)
        layers += stage(i_c, o_c, 3, i_c if dw else 1)
        layers += stage(o_c, o_c, 1, 1)
        self.lyr = nn.Sequential(*layers)

    def forward(self, x):
        """Run the input through the nine-layer sequential stack."""
        out = self.lyr(x)
        return out

class TestCNN(nn.Module):
    """Stack of eight ``TestConv`` blocks with optional quant/dequant stubs.

    The channel count doubles in every block (1 -> 2 -> 4 -> ... -> 256).

    Args:
        q: when true, ``forward`` wraps the network with ``self.quant`` /
            ``self.dequant``; also forwarded to every ``TestConv``.
        dw: forwarded to every ``TestConv``.
    """

    def __init__(self, q, dw):
        super(TestCNN, self).__init__()
        self.q = q
        self.quant = Q.QuantStub()
        self.dequant = Q.DeQuantStub()
        # widths[k] -> widths[k + 1] are the in/out channels of block k.
        widths = [2 ** k for k in range(9)]
        blocks = [
            TestConv(q=q, dw=dw, i_c=c_in, o_c=c_out)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.cnn = nn.Sequential(*blocks)

    def fuse_model(self):
        """Fuse each (conv, bn, activation) triple of every ``TestConv`` in place."""
        triples = (['0', '1', '2'], ['3', '4', '5'], ['6', '7', '8'])
        for blk in self.modules():
            if type(blk) != TestConv:
                continue
            for names in triples:
                Q.fuse_modules(blk.lyr, names, inplace=True)

    def forward(self, x):
        """Apply the block stack, quantizing before and dequantizing after when ``self.q`` is set."""
        use_stubs = self.q
        if use_stubs:
            x = self.quant(x)
        x = self.cnn(x)
        if use_stubs:
            x = self.dequant(x)
        return x

def q_test(dw):
    """Build, save and script a float model and its statically quantized twin.

    Files written to the current directory (``<tag>`` is ``dw`` or ``cmn``):
      float.<tag>.pt  - state dict of the float model
      jit.f.<tag>.pt  - scripted float model
      jit.q.<tag>.pt  - scripted fbgemm-quantized model

    Args:
        dw: forwarded to ``TestCNN``; also selects the file-name tag.
    """
    tag = 'dw' if dw else 'cmn'
    float_weights = 'float.{}.pt'.format(tag)

    def _eval(m):
        # Feed ten random batches through the model in eval mode without gradients.
        m.eval()
        with torch.no_grad():
            for _ in range(10):
                m(torch.randn(10, 1, 100, 100))

    print('\nno quantization\n')
    float_model = TestCNN(q=False, dw=dw)
    torch.save(float_model.state_dict(), float_weights)
    scripted_float = torch.jit.script(float_model)
    torch.jit.save(scripted_float, 'jit.f.{}.pt'.format(tag))

    print('\npost-training static quantization\n')
    quant_model = TestCNN(q=True, dw=dw)
    quant_model.load_state_dict(torch.load(float_weights, map_location='cpu'))
    quant_model.eval()
    quant_model.fuse_model()
    quant_model.qconfig = Q.get_default_qconfig('fbgemm')
    Q.prepare(quant_model, inplace=True)
    _eval(quant_model)  # calibration
    Q.convert(quant_model, inplace=True)
    scripted_quant = torch.jit.script(quant_model)
    torch.jit.save(scripted_quant, 'jit.q.{}.pt'.format(tag))

# Dump float and quant models, first without depthwise conv, then with it.
for use_depthwise in (False, True):
    q_test(dw=use_depthwise)

3.Run JIT model in Libtorch

#include <torch/script.h>
#include <torch/torch.h>
#include <pthread.h>
#include <omp.h>
#include <algorithm>
#include <iostream>
#include <chrono>
#include <vector>
#include <numeric>

// Arguments passed to one benchmark worker thread through pthread_create.
struct t_s_param {
    torch::jit::script::Module * sess;  // module whose "forward" method the worker calls
    int loop_cnt;                       // number of forward passes the worker runs
    int * ms;                           // out: mean elapsed milliseconds per pass
    int * min_ms;                       // out: smallest elapsed milliseconds seen
    int * max_ms;                       // out: largest elapsed milliseconds seen
};
typedef t_s_param s_param;

// Options shared by every input tensor in this program: float32, on CPU, no autograd tracking.
torch::TensorOptions g_options =
    torch::TensorOptions()
        .device(torch::kCPU)
        .dtype(torch::kFloat32)
        .requires_grad(false);

/**
 * Loads a TorchScript module, moves it to the CPU, switches it to eval mode
 * and runs one timed warm-up forward pass on a random {1, 1, 32, 100} input.
 *
 * Prints "warmup <elapsed ms>" to stdout.
 *
 * @param model_file_name Path of the serialized TorchScript module.
 * @return The loaded module.
 */
torch::jit::script::Module load(const char * model_file_name)
{
    torch::NoGradGuard no_guard;

    torch::jit::script::Module module = torch::jit::load(model_file_name);
    module.to(torch::kCPU);
    module.eval();

    torch::Tensor x = torch::randn({ 1, 1, 32, 100 }, g_options);
    // steady_clock is monotonic, so the measured interval cannot be distorted
    // by wall-clock adjustments the way a system_clock interval can.
    const std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
    torch::Tensor y = module.forward({x}).toTensor();
    const std::chrono::milliseconds elapsed =
        std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start);
    std::cout << "warmup " << elapsed.count() << std::endl;

    return module;
}

void * working_thread(void * param)
{
    torch::init_num_threads();

    int * ms = ((s_param *)param)->ms;
    int * min_ms = ((s_param *)param)->min_ms;
    int * max_ms = ((s_param *)param)->max_ms;
    for (int idx = 0; idx < ((s_param *)param)->loop_cnt; ++idx) {
        torch::NoGradGuard no_guard;
        torch::Tensor x = torch::randn({ 1, 1, 32, 1000 }, g_options);
        std::chrono::system_clock::time_point start = std::chrono::system_clock::now();
        torch::Tensor y = ((s_param *)param)->sess->get_method("forward")({x}).toTensor();
        std::chrono::milliseconds elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now() - start);
        int elapsed_ms = elapsed.count();
        *ms += elapsed_ms;
        if (*min_ms == 0 || *min_ms > elapsed_ms) { *min_ms = elapsed_ms; }
        if (*max_ms == 0 || *max_ms < elapsed_ms) { *max_ms = elapsed_ms; }
    }
    *ms /= ((s_param *)param)->loop_cnt;
    std::cout << "thread quit" << std::endl;
    return 0;
}

/**
 * Benchmark driver.
 *
 * Usage: <program> <jit_model.pt>
 *
 * Limits OpenMP and libtorch to one thread each, loads the model (including a
 * warm-up pass), runs `thread_cnt` worker threads of `loop_cnt` forward passes
 * each, and prints the mean / min / max per-pass time in milliseconds.
 *
 * @return 0 on success, 1 on bad usage or when a worker thread could not be started.
 */
int main(int argc, char ** argv)
{
    if (argc != 2) {
        std::cerr << "usage: " << (argc > 0 ? argv[0] : "benchmark") << " <jit_model.pt>" << std::endl;
        return 1;
    }

    // Benchmark configuration. These names were used below without being
    // declared. One worker matches the single-thread setup described with the
    // experiment; the iteration count is an arbitrary choice - adjust as needed.
    const int thread_cnt = 1;
    const int loop_cnt = 100;

    omp_set_num_threads(1);
    torch::set_num_threads(1);
    torch::set_num_interop_threads(1);

    torch::jit::script::Module module = load(argv[1]);

    // create thread
    std::vector<int> ms(thread_cnt, 0);
    std::vector<int> min_ms(thread_cnt, 0);
    std::vector<int> max_ms(thread_cnt, 0);
    // `param` is fully sized before any thread starts, so the element
    // addresses handed to the workers stay valid.
    std::vector<s_param> param(thread_cnt);
    std::vector<pthread_t> thread_handle;
    thread_handle.reserve(thread_cnt);
    bool all_started = true;
    for (int idx = 0; idx < thread_cnt; ++idx) {
        param[idx].sess = &module;
        param[idx].loop_cnt = loop_cnt;
        param[idx].ms = &ms[idx];
        param[idx].min_ms = &min_ms[idx];
        param[idx].max_ms = &max_ms[idx];
        pthread_t sub_handle;
        if (pthread_create(&sub_handle, nullptr, working_thread, &param[idx]) != 0) {
            std::cerr << "failed to create thread " << idx << std::endl;
            all_started = false;
            break;
        }
        thread_handle.push_back(sub_handle);
    }
    // Join only the threads that were actually started.
    for (std::size_t idx = 0; idx < thread_handle.size(); ++idx) {
        pthread_join(thread_handle[idx], nullptr);
    }
    if (!all_started) {
        // Slots of threads that never ran still hold 0 and would skew the statistics.
        return 1;
    }

    // Divide in floating point; int / size_t would truncate the mean.
    const float mean_time = static_cast<float>(std::accumulate(ms.begin(), ms.end(), 0)) / static_cast<float>(ms.size());
    const float min_time = static_cast<float>(*std::min_element(min_ms.begin(), min_ms.end()));
    const float max_time = static_cast<float>(*std::max_element(max_ms.begin(), max_ms.end()));
    std::cout << "mean time : " << mean_time << std::endl;
    std::cout << "min  time : " << min_time << std::endl;
    std::cout << "max  time : " << max_time << std::endl;

    return 0;
}

4.Experiment result

Run float model without depthwise:
mean time : 648
min time : 642
max time : 805

Run quant model without depthwise:
mean time : 478
min time : 474
max time : 533

Run float model with depthwise:
mean time : 422
min time : 376
max time : 608

Run quant model with depthwise:
mean time : 1731
min time : 1725
max time : 1828