Does the JIT make a model faster?

Are there any JIT performance measurements? Does it make a model any faster, or is the only benefit of involving the JIT the ability to save a model and perform inference in environments other than Python?

Yes, we do monitor the performance of certain bits. For example, the recent PyTorch blog post on RNN speedups discusses benchmarks we’ve been monitoring quite closely and continue to work against. ResNet performance is also regularly checked.

That said, whether any given model sees significant speedups depends.

  • I always give the ballpark figure of a 10% speedup for moving from Python to C++. I got this number from a couple of specific models, e.g. when you do a 1:1 translation into C++ of the LLTM model used in the C++ extension tutorial. Your model will see different numbers. A similar speedup is probably there for the JIT.
  • Where the JIT really gets large speedups is when one of its optimizations can fully come into play. E.g. if you have chains of elementwise operations, they will be fused into a single kernel. As those ops are typically memory-bound, fusing two elementwise ops will be ~2x as fast as running them separately (a sketch follows this list).
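
Here is a minimal sketch of that kind of pattern, timing eager vs. scripted. The function scaled_tanh, the shapes, and the iteration counts are made up for illustration, and a CUDA device is assumed because fusion only runs on the GPU by default; the exact speedup depends on your GPU and PyTorch version.

import time
import torch

# A chain of elementwise ops (mul, tanh, add) that the fuser can merge
# into a single kernel -- illustrative only, not from the thread.
def scaled_tanh(x, y):
    return torch.tanh(x * 2.0) + y

scripted = torch.jit.script(scaled_tanh)

x = torch.randn(4096, 4096, device='cuda')
y = torch.randn(4096, 4096, device='cuda')

for fn, name in [(scaled_tanh, 'eager'), (scripted, 'scripted')]:
    for _ in range(3):               # warm-up runs let the JIT profile and fuse
        fn(x, y)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        fn(x, y)
    torch.cuda.synchronize()
    print(name, (time.time() - start) / 100 * 1e6, "µs per call")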

Best regards

Thomas

I traced the BERT model from the HuggingFace PyTorch-Transformers library and am getting the following results for 10 iterations.
a) Using Python runtime for running the forward: 979292 µs

import time
import torch

model = torch.jit.load('models_backup/2_2.pt')
x = torch.randint(2000, (1, 14), dtype=torch.long, device='cpu')
start = time.time()
for i in range(10):
    model(x)
end = time.time()
print((end - start) * 1000000, "µs")

b) Using the C++ runtime for running the forward: 3333758 µs, which is almost 3x the Python time.

  // index_max, inputsize and module are defined earlier in the program (not shown).
  std::vector<torch::jit::IValue> input;
  torch::Tensor x = torch::randint(index_max, {1, inputsize}, torch::dtype(torch::kInt64).device(torch::kCPU));
  input.push_back(x);
  // Warm-up: execute the model once and turn its output into a tuple.
  auto outputs = module->forward(input).toTuple();
  auto start = chrono::steady_clock::now();
  for (int16_t i = 0; i < 10; ++i)
  {
    outputs = module->forward(input).toTuple();
  }
  auto end = chrono::steady_clock::now();
  cout << "Elapsed time in microseconds : "
       << chrono::duration_cast<chrono::microseconds>(end - start).count()
       << " µs" << endl;

@tom any suggestions on what I am missing?

You are not even doing the comparison I had in mind. If the C++ side uses the JIT, you are comparing the JIT called from Python vs. the JIT called from C++, and that should really run at the same speed modulo a constant overhead (which is not 6 s).
Are you using the same inputs, libtorch version, environment, …?

Best regards

Thomas

Hi Tom,
Which layers can be seen as chains of elementwise ops? Are there any detailed benchmarks? Thank you very much in advance.

Best regards,

Edward

I teach that in my PyTorch internals training, if you’re near Munich and want to book a seat… 🙂
But the theoretical answer is: any sequence of elementwise ops. The practical answer is: anything that you see merged into fusion groups in myfn.graph_for(*my_inputs) (a sketch follows below). Note that fusion is only done on the GPU by default.
In addition to the blog post linked above, there is my blog post on optimizing the LSTM backward and an old talk of mine that works through this on a simple example (IoU) in detail.
Obviously, a lot more is to be had from extending what can be fused. TorchTVM is a (highly experimental) approach to hook into TVM, which is an (also experimental, in my experience) framework that can also optimize reductions, which the JIT cannot.
Personally, I think a lot more could be had, but I’m not sure who is making that a top priority at the moment.
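
Here is a sketch of that graph_for inspection. The function myfn, its inputs, and the shapes are made up for illustration, and a CUDA device is assumed since that is where fusion happens by default.

import torch

# Script a small chain of elementwise ops and inspect the optimized graph.
@torch.jit.script
def myfn(x, y):
    return torch.sigmoid(x) * torch.tanh(y) + x

my_inputs = (torch.randn(1024, 1024, device='cuda'),
             torch.randn(1024, 1024, device='cuda'))

for _ in range(3):       # a few warm-up calls so the optimization passes run
    myfn(*my_inputs)

# Elementwise ops that were fused show up grouped into fusion-group nodes
# in the graph printed here.
print(myfn.graph_for(*my_inputs))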

Best regards

Thomas