Hi all,
I guess I could have titled my topic: “Can torch.jit be used for the same purpose as numba’s njit or is it something strictly used to optimize models for inference?”, but this sounded a bit too long.
I am trying to understand why this torch jitted function is so very slow. I think this code snippet is better than any attempt to explain it:
import timeit
import torch
def if_max_torch_base(tensor: torch.Tensor, threshold=torch.Tensor) -> torch.Tensor:
"Just the base pytorch way"
return torch.any(tensor > threshold)
@torch.jit.script
def if_max_torch_jit(tensor: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
"Attempt at exploiting torch.jit"
for x in tensor:
if x > threshold:
return torch.tensor(True) # no need to test further, let's just return early
return torch.tensor(False)
if_max_torch_jit = torch.jit.trace(if_max_torch_jit, (torch.rand(5), torch.tensor(0.5)))
def main():
tensor = torch.linspace(0, 1, 1_000_000, device="cuda")
for name, func in (
["base", if_max_torch_base],
["jit", if_max_torch_jit],
):
for threshold in (0.5, 1.0):
print(name, threshold)
t = torch.tensor(threshold, device="cuda")
timer = timeit.Timer(lambda: func(tensor, t))
print(timer.repeat(repeat=3, number=100))
main()
And this is the output:
base 0.5
[0.0002641710452735424, 0.00016550999134778976]
base 1.0
[0.00017875991761684418, 0.00016467086970806122]
jit 0.5
[70.17099338211119, 71.27563373814337]
jit 1.0
[139.18801530217752, 139.25591901200823]
Why does this jitted function have abysmal performance? Am I just completely misusing torch.jit? Can torch.jit can event be used for this purpose?
This is what if_max_torch_jit.code
returns:
def if_max_torch_jit(tensor: Tensor,
threshold: Tensor) -> Tensor:
_0 = uninitialized(Tensor)
_1 = torch.len(tensor)
_2 = False
_3 = _0
_4 = 0
_5 = torch.gt(_1, 0)
while _5:
x = torch.select(tensor, 0, _4)
if bool(torch.gt(x, threshold)):
_6, _7, _8 = False, True, torch.tensor(True)
else:
_6, _7, _8 = True, False, _0
_9 = torch.add(_4, 1)
_5, _2, _3, _4 = torch.__and__(torch.lt(_9, _1), _6), _7, _8, _9
if _2:
_10 = _3
else:
_10 = torch.tensor(False)
return _10
This seems overly convoluted, but I am not familiar with low-level programming at all…
Thanks for reading me and thanks in advance for educating me.