LibTorch inference runs slower than ONNX Runtime and TorchScript JIT

Hi, I’m trying to implement C++ inference for a Hugging Face BERT (mini) model. Here are the code snippets I use to measure elapsed time. First, the C++ (LibTorch) side:

#include <torch/script.h> // One-stop header.

#include <chrono>   // needed for high_resolution_clock
#include <iostream>
#include <memory>
#include <vector>

using namespace std::chrono;
using namespace std;

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::cout << "ok\n";

  // Build the input: a single sequence of 150 token ids.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 150}, at::kLong));

  // Disable autograd and time a single forward pass.
  torch::NoGradGuard no_grad;
  auto start = high_resolution_clock::now();
  auto output = module.forward(inputs);
  auto stop = high_resolution_clock::now();
  auto duration = duration_cast<milliseconds>(stop - start);

  cout << duration.count() << " ms" << std::endl;
  std::cout << output;
  return 0;
}
And the Python benchmark comparing ONNX Runtime against the same TorchScript model:

import torch
import numpy as np
from time import time
import onnxruntime as ort

# Load the exported ONNX model and the TorchScript model.
sess = ort.InferenceSession('clf_mini.onnx', providers=['CPUExecutionProvider'])
model = torch.jit.load('jit_mini.pt')
model.eval()
inp = torch.zeros((1, 150), dtype=torch.int64)

# Average ONNX Runtime latency over N runs.
N = 1000
times = []
for i in range(N):
    t0 = time()
    r = sess.run(None, {'input_values': np.zeros((1, 150), dtype=np.int64)})[0]
    times.append(time() - t0)
print('ONNX Runtime:', 1000 * sum(times) / N, 'ms')

# Average TorchScript (JIT) latency over N runs;
# no_grad matches the NoGradGuard in the C++ code.
N = 1000
times = []
with torch.no_grad():
    for i in range(N):
        t0 = time()
        r = model(inp)
        times.append(time() - t0)
print('Torch JIT Runtime:', 1000 * sum(times) / N, 'ms')

Results: ONNX Runtime and TorchScript JIT take ~5 ms per forward, while the LibTorch C++ call takes ~70 ms.

Is this expected behavior, or did I do something wrong?

I found the problem: on the first one or two forward calls in C++, LibTorch appears to perform some kind of warm-up or one-time optimization, so those passes take longer. After that, it has the same execution time as the Python version.
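For anyone benchmarking this themselves, here is a minimal sketch of the adjusted C++ timing (reusing module and the headers from the snippet above; the warm-up count of 5 and N = 1000 are arbitrary choices on my part, not anything LibTorch requires):

  torch::NoGradGuard no_grad;
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 150}, at::kLong));

  // Run a few untimed forwards first, so any one-time warm-up work
  // happens outside the measured region.
  for (int i = 0; i < 5; ++i) {
    module.forward(inputs);
  }

  // Then average the latency over many runs, as in the Python benchmark.
  const int N = 1000;
  auto start = high_resolution_clock::now();
  for (int i = 0; i < N; ++i) {
    module.forward(inputs);
  }
  auto stop = high_resolution_clock::now();
  auto total_ms = duration_cast<microseconds>(stop - start).count() / 1000.0;
  std::cout << total_ms / N << " ms per forward" << std::endl;

With the warm-up passes excluded, the per-forward time matches the Python numbers for me.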
