import torch
from numba import jit

@torch.jit.script
def torch_jit_sum(x: torch.Tensor):
    res = torch.zeros_like(x[0, 0])
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            res += x[i, j]
    return res

def blablabla(x):
    with torch.no_grad():
        return torch_jit_sum(x)

def loop_sum(x):
    res = torch.zeros_like(x[0, 0])
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            res += x[i, j]
    return res

@jit
def numba_sum(x):
    res = 0.0
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            res += x[i, j]
    return res
Hi @Daulbaev! The reason torch.jit is slow in your example is that it's not designed for this particular use case. The kinds of speedups you will see with torch.jit come from situations where, for example, you have a number of pointwise operations, and torch.jit is able to fuse them together, eliminating overhead and memory traffic. At this point in time, torch.jit is not designed to take pointwise loops as you've written here and compile them directly into machine code.
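To make the fusion point concrete, here is a minimal sketch of the kind of code the JIT fuser is aimed at: a chain of pointwise operations that can be combined into fewer kernels, avoiding intermediate tensors. The function name `fused_pointwise` is just an illustration, not anything from your example:

```python
import torch

@torch.jit.script
def fused_pointwise(x: torch.Tensor) -> torch.Tensor:
    # A chain of elementwise ops: this is the pattern the JIT can fuse,
    # so intermediates like sigmoid(x) need not be materialized separately.
    return torch.sigmoid(x) * torch.tanh(x) + x.relu()

x = torch.randn(1000, 1000)
out = fused_pointwise(x)
```

The scripted function computes exactly the same values as the eager version; the potential win is in memory traffic, not semantics.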
Given that Numba JIT-compiles single CUDA kernels, it's going to be at least as fast in execution.
However, for many workloads the expressive power of PyTorch is much greater, and the JIT will take those ops and optimize them.
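In practice, the way to get that benefit today is to express the double loop as a built-in tensor reduction instead of Python-level iteration. A sketch (the name `vectorized_sum` is my own, for illustration):

```python
import torch

@torch.jit.script
def vectorized_sum(x: torch.Tensor) -> torch.Tensor:
    # A single reduction op replaces the nested Python loops; both eager
    # mode and the JIT dispatch this to one optimized kernel.
    return x.sum()

total = vectorized_sum(torch.ones(3, 4))
```

This should match `loop_sum` numerically while avoiding the per-element interpreter (or TorchScript) overhead entirely.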