Why is torch.jit so slow?

Hello! Today I compared numba.jit and torch.jit and was very surprised. What am I doing wrong?

import torch
from numba import jit

@torch.jit.script
def torch_jit_sum(x: torch.Tensor):
    res = torch.zeros_like(x[0, 0])
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            res += x[i, j]
    return res

def blablabla(x):
    with torch.no_grad():
        return torch_jit_sum(x)

def loop_sum(x):
    res = torch.zeros_like(x[0, 0])
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            res += x[i, j]
    return res

@jit
def numba_sum(x):
    res = 0.0
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            res += x[i, j]
    return res

Hi @Daulbaev! torch.jit is slow in your example because it's not designed for this particular use case :) The speedups you will see with torch.jit come in situations where, for example, you have a number of pointwise operations that torch.jit can fuse together, eliminating overhead and memory traffic. At this point in time, torch.jit is not designed to take elementwise loops like the ones you've written here and compile them directly into machine code.
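
To make that concrete, here is a rough sketch of the two patterns. The function names (`scripted_pointwise`, `eager_pointwise`, `vectorized_sum`) and the tensor shapes are made up for illustration. The first function is a chain of pointwise ops, which is the kind of code a fuser can in principle merge; whether fusion actually happens, and how much it helps, depends on your PyTorch version, device, and fuser settings, so measure it yourself. The second replaces the Python-level double loop with a single tensor op.

```python
import torch

@torch.jit.script
def scripted_pointwise(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # A chain of elementwise ops: a candidate for fusion by the JIT.
    # (Fusion is not guaranteed; it depends on version, device and fuser.)
    return torch.sigmoid(x) * torch.tanh(y) + x

def eager_pointwise(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Same computation in eager mode, for comparison.
    return torch.sigmoid(x) * torch.tanh(y) + x

def vectorized_sum(x: torch.Tensor) -> torch.Tensor:
    # One tensor op instead of a Python loop over every element.
    return x.sum()

# Illustrative inputs (shapes chosen arbitrarily).
x = torch.randn(64, 64)
y = torch.randn(64, 64)

# Both versions should agree numerically.
same = torch.allclose(scripted_pointwise(x, y), eager_pointwise(x, y))
total = vectorized_sum(x)
```

For the sum in your post, something like `vectorized_sum` is usually the first thing to try, since it avoids dispatching one tiny op per element.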

Thank you for a quick response. Do you have an example of a function where torch.jit on CPU is faster than numba.jit?

Given that numba.jit compiles single CUDA kernels, it's going to be at least as fast in execution.
However, for many things, the expressive power of PyTorch is much greater and the JIT will take those ops and optimize them.

Best regards

Thomas
