Hi @dskhudia,
I tested it with both 1 and 10 threads. To keep things simple, let's test with a single thread and limit OpenMP to 1 thread. Here are the details of the experiment.
1. Install PyTorch 1.4 and download LibTorch 1.4
2. Prepare the JIT models in PyTorch
import torch
from torch import nn
import torch.quantization as Q
class TestConv(nn.Module):
    """Three Conv2d -> BatchNorm2d -> activation stages in one ``nn.Sequential``.

    Stage layout (the Sequential indices matter to any code that addresses
    the layers by position, e.g. for module fusion):
      0-2: 1x1 convolution, i_c -> i_c
      3-5: 3x3 convolution, i_c -> o_c (grouped with ``groups=i_c`` when ``dw``)
      6-8: 1x1 convolution, o_c -> o_c

    Args:
        q: when truthy every activation is ``nn.ReLU(inplace=False)``,
            otherwise ``nn.ReLU6(inplace=True)``.
        dw: when truthy the 3x3 convolution uses ``groups=i_c``; otherwise
            ``groups=1``.
        i_c: number of input channels.
        o_c: number of output channels.
    """

    def __init__(self, q, dw, i_c, o_c):
        super(TestConv, self).__init__()
        self.lyr = nn.Sequential(
            nn.Conv2d(in_channels=i_c, out_channels=i_c, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=False),
            nn.BatchNorm2d(num_features=i_c),
            nn.ReLU(inplace=False) if q else nn.ReLU6(inplace=True),
            nn.Conv2d(in_channels=i_c, out_channels=o_c, kernel_size=3, stride=1, padding=1, dilation=1, groups=i_c if dw else 1, bias=False),
            nn.BatchNorm2d(num_features=o_c),
            nn.ReLU(inplace=False) if q else nn.ReLU6(inplace=True),
            nn.Conv2d(in_channels=o_c, out_channels=o_c, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=False),
            nn.BatchNorm2d(num_features=o_c),
            nn.ReLU(inplace=False) if q else nn.ReLU6(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through the nine-layer stack and return the result."""
        return self.lyr(x)
class TestCNN(nn.Module):
    """Stack of eight ``TestConv`` blocks whose channel count doubles each block.

    The first block maps 1 -> 2 channels and the last 128 -> 256. When ``q``
    is truthy the input passes through a ``QuantStub`` before the stack and
    the output through a ``DeQuantStub`` after it.

    Args:
        q: build the quantization-ready variant (forwarded to ``TestConv``
            and used to gate the quant/dequant stubs in ``forward``).
        dw: forwarded to ``TestConv``; selects the grouped 3x3 convolution.
    """

    def __init__(self, q, dw):
        super(TestCNN, self).__init__()
        self.q = q
        self.quant = Q.QuantStub()
        self.dequant = Q.DeQuantStub()
        # Collect the blocks in a local list so that ``self.cnn`` is only
        # ever assigned the final nn.Sequential.
        blocks = []
        i_c = 1
        for _ in range(8):
            blocks.append(TestConv(q=q, dw=dw, i_c=i_c, o_c=i_c * 2))
            i_c *= 2
        self.cnn = nn.Sequential(*blocks)

    def fuse_model(self):
        """Fuse each Conv2d/BatchNorm2d/activation triple of every TestConv in place."""
        for m in self.modules():
            if isinstance(m, TestConv):
                # The string names are positions inside TestConv.lyr.
                Q.fuse_modules(m.lyr, ['0', '1', '2'], inplace=True)
                Q.fuse_modules(m.lyr, ['3', '4', '5'], inplace=True)
                Q.fuse_modules(m.lyr, ['6', '7', '8'], inplace=True)

    def forward(self, x):
        """Apply the stack to ``x``, wrapped in quant/dequant stubs when ``self.q`` is set."""
        if self.q:
            x = self.quant(x)
        x = self.cnn(x)
        if self.q:
            x = self.dequant(x)
        return x
def q_test(dw):
    """Export a float and a statically quantized TestCNN as TorchScript files.

    Files written to the current directory (``<tag>`` is ``'dw'`` when ``dw``
    is truthy, otherwise ``'cmn'``):
      float.<tag>.pt  - state_dict of the float model
      jit.f.<tag>.pt  - scripted float model
      jit.q.<tag>.pt  - scripted quantized model

    The quantized model loads the float model's weights, is fused, prepared
    with the default 'fbgemm' qconfig, calibrated on random inputs, and then
    converted.

    Args:
        dw: forwarded to ``TestCNN``; selects the depthwise variant.
    """
    tag = 'dw' if dw else 'cmn'
    float_path = 'float.{}.pt'.format(tag)

    def _eval(m):
        # Feed ten random batches through the model in eval mode; used as
        # the calibration pass between prepare() and convert().
        m.eval()
        with torch.no_grad():
            for _ in range(10):
                m(torch.randn(10, 1, 100, 100))

    print('\nno quantization\n')
    fm = TestCNN(q=False, dw=dw)
    torch.save(fm.state_dict(), float_path)
    torch.jit.save(torch.jit.script(fm), 'jit.f.{}.pt'.format(tag))

    print('\npost-training static quantization\n')
    qm = TestCNN(q=True, dw=dw)
    qm.load_state_dict(torch.load(float_path, map_location='cpu'))
    qm.eval()
    qm.fuse_model()
    qm.qconfig = Q.get_default_qconfig('fbgemm')
    Q.prepare(qm, inplace=True)
    _eval(qm)  # calibration
    Q.convert(qm, inplace=True)
    torch.jit.save(torch.jit.script(qm), 'jit.q.{}.pt'.format(tag))
# Dump the float and quantized models twice: first without depthwise
# convolution, then with it.
for use_depthwise in (False, True):
    q_test(dw=use_depthwise)
3. Run the JIT models in LibTorch
#include <torch/script.h>
#include <torch/torch.h>
#include <pthread.h>
#include <omp.h>
#include <algorithm>
#include <iostream>
#include <chrono>
#include <vector>
#include <numeric>
// Arguments handed to working_thread() through the void* parameter of
// pthread_create. All pointers are non-owning.
struct t_s_param {
    // Scripted module whose "forward" method the thread invokes.
    torch::jit::script::Module * sess;
    // Number of timed forward passes the thread performs.
    int loop_cnt;
    // Output: accumulated, then averaged, per-pass time in milliseconds.
    int * ms;
    // Output: smallest per-pass time in milliseconds.
    int * min_ms;
    // Output: largest per-pass time in milliseconds.
    int * max_ms;
};
using s_param = t_s_param;
// Options shared by every input tensor created in this file:
// float32 values, placed on the CPU, with gradient tracking disabled.
torch::TensorOptions g_options =
    torch::TensorOptions()
        .device(torch::kCPU)
        .dtype(torch::kFloat32)
        .requires_grad(false);
// Loads a TorchScript module from `model_file_name`, moves it to the CPU,
// switches it to eval mode and runs one warm-up forward pass whose duration
// (milliseconds) is printed as "warmup <ms>".
//
// Returns the loaded module.
torch::jit::script::Module load(const char * model_file_name)
{
    torch::NoGradGuard no_guard;
    torch::jit::script::Module module = torch::jit::load(model_file_name);
    module.to(torch::kCPU);
    module.eval();

    // NOTE(review): the warm-up input is 32x100 while the timed passes in
    // working_thread() use 32x1000 - confirm this difference is intended.
    torch::Tensor x = torch::randn({ 1, 1, 32, 100 }, g_options);

    // steady_clock is monotonic, so the measured interval cannot be skewed
    // by wall-clock adjustments the way a system_clock interval can.
    const std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
    torch::Tensor y = module.forward({x}).toTensor();
    const std::chrono::milliseconds elapsed =
        std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start);
    std::cout << "warmup " << elapsed.count() << std::endl;
    return module;
}
void * working_thread(void * param)
{
torch::init_num_threads();
int * ms = ((s_param *)param)->ms;
int * min_ms = ((s_param *)param)->min_ms;
int * max_ms = ((s_param *)param)->max_ms;
for (int idx = 0; idx < ((s_param *)param)->loop_cnt; ++idx) {
torch::NoGradGuard no_guard;
torch::Tensor x = torch::randn({ 1, 1, 32, 1000 }, g_options);
std::chrono::system_clock::time_point start = std::chrono::system_clock::now();
torch::Tensor y = ((s_param *)param)->sess->get_method("forward")({x}).toTensor();
std::chrono::milliseconds elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now() - start);
int elapsed_ms = elapsed.count();
*ms += elapsed_ms;
if (*min_ms == 0 || *min_ms > elapsed_ms) { *min_ms = elapsed_ms; }
if (*max_ms == 0 || *max_ms < elapsed_ms) { *max_ms = elapsed_ms; }
}
*ms /= ((s_param *)param)->loop_cnt;
std::cout << "thread quit" << std::endl;
return 0;
}
// Benchmark driver.
//
// Usage: <program> <torchscript-model-file>
//
// Pins OpenMP and the torch intra-op/inter-op pools to one thread each, loads
// the model (including one warm-up pass), runs `thread_cnt` worker threads
// that each perform `loop_cnt` timed forward passes, and prints the mean,
// minimum and maximum per-pass time in milliseconds.
//
// Returns 0 on success, 1 on a usage error or when a worker thread could not
// be created.
int main(int argc, char ** argv)
{
    if (argc != 2) {
        std::cerr << "usage: " << (argc > 0 ? argv[0] : "benchmark") << " <model.pt>" << std::endl;
        return 1;
    }

    // NOTE(review): the posted code used thread_cnt and loop_cnt without
    // declaring them. One thread matches the experiment described in the
    // post; the loop count was never stated, so 100 is a placeholder.
    const int thread_cnt = 1;
    const int loop_cnt = 100;

    omp_set_num_threads(1);
    torch::set_num_threads(1);
    torch::set_num_interop_threads(1);

    torch::jit::script::Module module = load(argv[1]);

    // Per-thread result slots; each worker writes only to its own index.
    std::vector<int> ms(thread_cnt, 0);
    std::vector<int> min_ms(thread_cnt, 0);
    std::vector<int> max_ms(thread_cnt, 0);
    std::vector<s_param> param(thread_cnt);

    // create thread
    std::vector<pthread_t> thread_handle;
    thread_handle.reserve(thread_cnt);
    for (int idx = 0; idx < thread_cnt; ++idx) {
        param[idx].sess = &module;
        param[idx].loop_cnt = loop_cnt;
        param[idx].ms = &ms[idx];
        param[idx].min_ms = &min_ms[idx];
        param[idx].max_ms = &max_ms[idx];

        pthread_t sub_handle;
        const int rc = pthread_create(&sub_handle, nullptr, working_thread, &param[idx]);
        if (rc != 0) {
            // Do not store (and later join) a handle that was never created.
            std::cerr << "pthread_create failed: " << rc << std::endl;
            break;
        }
        thread_handle.push_back(sub_handle);
    }

    for (pthread_t handle : thread_handle) {
        pthread_join(handle, nullptr);
    }
    if (thread_handle.size() != static_cast<std::size_t>(thread_cnt)) {
        return 1;
    }

    // Divide in floating point so the mean across threads is not truncated.
    const float mean_time =
        static_cast<float>(std::accumulate(ms.begin(), ms.end(), 0)) / static_cast<float>(ms.size());
    const float min_time = static_cast<float>(*std::min_element(min_ms.begin(), min_ms.end()));
    const float max_time = static_cast<float>(*std::max_element(max_ms.begin(), max_ms.end()));

    std::cout << "mean time : " << mean_time << std::endl;
    std::cout << "min time : " << min_time << std::endl;
    std::cout << "max time : " << max_time << std::endl;
    return 0;
}
4. Experiment results (all times in milliseconds)
Run float model without depthwise:
mean time : 648
min time : 642
max time : 805
Run quant model without depthwise:
mean time : 478
min time : 474
max time : 533
Run float model with depthwise:
mean time : 422
min time : 376
max time : 608
Run quant model with depthwise:
mean time : 1731
min time : 1725
max time : 1828