Returning early if tensor contains a value above a certain threshold in a jitted function

Hi all,

I guess I could have titled my topic: “Can torch.jit be used for the same purpose as numba’s njit or is it something strictly used to optimize models for inference?”, but this sounded a bit too long.

I am trying to understand why this torch jitted function is so very slow. I think this code snippet is better than any attempt to explain it:

import timeit
import torch


def if_max_torch_base(tensor: torch.Tensor, threshold=torch.Tensor) -> torch.Tensor:
    "Just the base pytorch way"
    return torch.any(tensor > threshold)


@torch.jit.script
def if_max_torch_jit(tensor: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
    "Attempt at exploiting torch.jit"
    for x in tensor:
        if x > threshold:
            return torch.tensor(True)  # no need to test further, let's just return early
    return torch.tensor(False)

if_max_torch_jit = torch.jit.trace(if_max_torch_jit, (torch.rand(5), torch.tensor(0.5)))


def main():
    tensor = torch.linspace(0, 1, 1_000_000, device="cuda")

    for name, func in (
        ["base", if_max_torch_base],
        ["jit", if_max_torch_jit],
    ):
        for threshold in (0.5, 1.0):
            print(name, threshold)
            t = torch.tensor(threshold, device="cuda")
            timer = timeit.Timer(lambda: func(tensor, t))
            print(timer.repeat(repeat=3, number=100))


main()

And this is the output:

base 0.5
[0.0002641710452735424, 0.00016550999134778976]
base 1.0
[0.00017875991761684418, 0.00016467086970806122]
jit 0.5
[70.17099338211119, 71.27563373814337]
jit 1.0
[139.18801530217752, 139.25591901200823]

Why does this jitted function have abysmal performance? Am I just completely misusing torch.jit? Can torch.jit can event be used for this purpose?

This is what if_max_torch_jit.code returns:

def if_max_torch_jit(tensor: Tensor,
    threshold: Tensor) -> Tensor:
  _0 = uninitialized(Tensor)
  _1 = torch.len(tensor)
  _2 = False
  _3 = _0
  _4 = 0
  _5 = torch.gt(_1, 0)
  while _5:
    x = torch.select(tensor, 0, _4)
    if bool(torch.gt(x, threshold)):
      _6, _7, _8 = False, True, torch.tensor(True)
    else:
      _6, _7, _8 = True, False, _0
    _9 = torch.add(_4, 1)
    _5, _2, _3, _4 = torch.__and__(torch.lt(_9, _1), _6), _7, _8, _9
  if _2:
    _10 = _3
  else:
    _10 = torch.tensor(False)
  return _10

This seems overly convoluted, but I am not familiar with low-level programming at all…

Thanks for reading me and thanks in advance for educating me. :slight_smile: