TorchScript model is slow

I’ve made a PyTorch model that uses a bit of dynamic Python, such as for-loops where the number of iterations is given by entries of a tensor. (It’s similar to a graph neural net. I have since seen PyTorch Geometric, but it seems those models can’t be exported to libtorch.) I believe this is why the model is very slow to train and run. I need to be able to export and use this model from C. I used torch.jit.script and successfully saved the model, then used libtorch to get it into C++ and wrote a header file so I can call the model from C.

The model is very slow to train in Python. It was my understanding that a model should become much faster after running torch.jit.script on it (I thought this turned slow Python into fast C++). Would this model train faster if I trained it in C++? Why or why not? What if I just saved the model and reloaded it back into Python?

I haven’t included the code here because I couldn’t find an easy way to reduce it to a minimal example free of the confusing specifics of my research domain. Sorry if this question is hard to grok.

I’m having a hard time understanding why PyTorch models are slow or fast and how, in general, they can be made faster. If TorchScript makes a dynamic graph static and able to run from C++, why doesn’t it make a model much faster? (Or does it, and I’m using it wrong?)

Thanks in advance,

Note that both the Python and C++ frontends call into the same backend methods.
In the usual case you would run your model on the GPU, so Python is used for the model “logic” and for dispatching the kernels (to C++ and CUDA).
For “big” workloads the Python overhead might be marginal.

If you have a static graph (with your conditions etc.), some optimizations might be applied, such as fusing certain operations together.
However, whether fusing ops is possible depends on your actual model. Also note that this area is still being worked on.
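As a minimal sketch of the kind of fusion being described (this example and the name `fused_op` are my own, not from the post): a chain of pointwise ops in a scripted function is the case the JIT fuser can potentially combine into a single kernel.

```python
import torch

# A chain of elementwise ops: on a static graph the JIT fuser may
# combine these into one kernel instead of launching one per op.
@torch.jit.script
def fused_op(x):
    return torch.sigmoid(x * 2.0 + 1.0)

x = torch.randn(1000)
out = fused_op(x)  # numerically matches the eager computation
```

Whether fusion actually kicks in depends on the device and PyTorch version, but the scripted function is guaranteed to compute the same result as the eager version.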

Thanks for your speedy reply!

I understand that things like the forward method of a linear layer use a fast backend. I’m wondering about the slow Python code (like Python for-loops and other things that don’t have explicit PyTorch functions). I’m wondering whether such code is made faster by torch.jit.script.

Perhaps my point can be seen through the code below. If I had implemented the code below in C or C++, it would be much faster than it is here. Is TorchScript not able to turn these slow Python for-loops into fast C for-loops? If I ran this saved module after loading it into libtorch in C++, would it run significantly faster? If so, why isn’t Python taking advantage of these benefits? I may be missing the point entirely, but do you have any way of making code like the following fast using PyTorch?

import torch
from torch import nn
from timeit import timeit

class TestModule(nn.Module):

    def __init__(self, m, n):
        super().__init__()  # required for nn.Module (and for torch.jit.script to work)
        self.m = m
        self.n = n

    def forward(self):
        val = 0
        for i in range(self.m):
            for j in range(self.n):
                val += i*j
        return val

def purePython(m,n):
    val = 0
    for i in range(m):
        for j in range(n):
            val += i*j
    return val

m = 400
n = 400

# regular pytorch module
model = TestModule(m,n)

# jit'ed module
jitModel = torch.jit.script(model)

# saved and reloaded jit'ed module
jitModel.save("")
saved = torch.jit.load("")

numTrials = 800
print(f"pure python function time: {timeit('purePython(m,n)',  number=numTrials, globals=globals()  )}")
print(f"pytorch model time:        {timeit('model()',          number=numTrials, globals=globals()  )}")
print(f"pytorch jit model time:    {timeit('jitModel()',       number=numTrials, globals=globals()  )}")
print(f"pytorch saved model time:  {timeit('saved()',          number=numTrials, globals=globals()  )}")

On my computer this was the output:

pure python function time: 7.517180057002406
pytorch model time:        7.449000566004543
pytorch jit model time:    13.549996405999991
pytorch saved model time:  14.33033561299817

I wrote the same code in C++ and it ran in less than 1 second.
I also tried importing the pytorch model in libtorch C++ and it took 13 seconds just like in python.
Is there any way to make for loops like this faster in pytorch?


I’m not sure what kind of optimizations are currently implemented for these plain Python loops in the JIT. :confused:

Note that you are currently not using any tensors, only Python literals.
I assume the first optimizations would target the usage of tensors.
If you define val = torch.zeros(1), you should see a speedup of the jitted models in comparison to the eager model.
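For instance, a sketch of that change applied to the module from the post (the class name `TensorTestModule` is mine, and `super().__init__()` is added so scripting works):

```python
import torch
from torch import nn

class TensorTestModule(nn.Module):
    def __init__(self, m: int, n: int):
        super().__init__()
        self.m = m
        self.n = n

    def forward(self):
        # Accumulate into a tensor instead of a Python int,
        # giving the JIT tensor ops it can work with.
        val = torch.zeros(1)
        for i in range(self.m):
            for j in range(self.n):
                val += i * j
        return val

jit_model = torch.jit.script(TensorTestModule(20, 20))
result = jit_model()
```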

However, you will not see any speedup compared to the pure Python function, since your code is not vectorized.
Vectorization usually gives you a good speedup, and loops should generally be avoided if possible.
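As a concrete sketch of vectorizing the double loop from the example above (the function name `vectorized_sum` is my own), the nested loops can be replaced by a single outer product and reduction:

```python
import torch

def vectorized_sum(m: int, n: int) -> torch.Tensor:
    # Sum of i*j over all pairs (i, j) with i < m, j < n,
    # computed with tensor ops instead of nested Python loops.
    i = torch.arange(m)
    j = torch.arange(n)
    return torch.outer(i, j).sum()

print(vectorized_sum(400, 400).item())  # -> 6368040000
```

Since the sum factors as (sum of i) * (sum of j), this particular example could even be computed in closed form; the outer-product version just illustrates the general pattern of pushing Python loops into tensor operations.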