Taking advantage of JIT optimization from C++

Hello! I’m relatively new to PyTorch, so this question may not make complete sense.

We are using a primarily C++ code base, so we are considering defining our models in C++. We are finding that for a toy example (GELU), the Python code converted to TorchScript and loaded in C++ is significantly faster (2-6x) than implementing the same operation directly in C++. I have since learned that this is because the JIT interpreter fuses the chain of elementwise operations into a single CUDA kernel, while the eager C++ version launches one kernel per operation.
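For reference, a minimal sketch of the setup described above (the exact GELU formula and file name are assumptions, not the original poster's code): a tanh-approximation GELU written as plain elementwise ops, scripted so the fuser can combine them, then saved for loading from C++.

```python
import torch

# Tanh-approximation GELU written as a chain of elementwise ops;
# when scripted and run on GPU, the JIT fuser can merge these
# into a single CUDA kernel instead of one launch per op.
@torch.jit.script
def gelu(x: torch.Tensor) -> torch.Tensor:
    # 0.7978845608 ~= sqrt(2 / pi)
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x * x * x)))

# Serialize the scripted function so C++ can load it,
# e.g. with torch::jit::load("gelu.pt") from libtorch.
gelu.save("gelu.pt")
```

On the C++ side the saved module would be loaded with `torch::jit::load` and called through `forward`, so the same fused kernels run there as well.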

Is this surprising? Is there any way to take advantage of these optimizations from a model defined in C++?

You could check what exactly is being fused by printing the optimized graph via `.graph_for` on this function/module. That should show the fusion groups and give you more information about the difference in your timings.
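A minimal sketch of that inspection, assuming a scripted GELU like the one in the question (`graph_for` is an internal, undocumented API, so its behavior can vary across PyTorch versions):

```python
import torch

@torch.jit.script
def gelu(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x * x * x)))

# Fusion happens for GPU tensors; fall back to CPU if no device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, device=device)

# Warm up: the profiling executor needs a few runs before it specializes
# and optimizes the graph for these input shapes/devices.
for _ in range(3):
    gelu(x)

# Fused ops appear as nodes such as prim::FusionGroup,
# prim::CudaFusionGroup, or prim::TensorExprGroup depending on the fuser.
print(gelu.graph_for(x))
```

If you see a single fusion group wrapping the whole elementwise chain, that one kernel launch is likely where the 2-6x gap comes from.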

Also, did you write a custom CUDA kernel for your C++ implementation, or are you comparing the fused GPU kernel against per-op C++ (host) code?