I am trying to learn more about the PyTorch JIT compiler (TorchScript) and have been implementing the examples from the documentation, in particular this one (from https://pytorch.org/docs/stable/jit.html#torch.jit.ScriptModule):
```python
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r
```
Running this code resulted in the following error:
Does anyone have an idea what is wrong?
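For comparison, here is a minimal sketch of how the documentation example is expected to behave once its indentation is intact. It assumes the code is saved to a `.py` file and run with `python`. `torch.jit.script` reads the function's source through Python's `inspect` module, and that lookup does not work for functions typed directly into the interactive interpreter. The tensors `a` and `b` are arbitrary inputs chosen for illustration.

```python
import torch


@torch.jit.script
def foo(x, y):
    # TorchScript compiles this data-dependent if/else into the graph
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


# Arbitrary example inputs (illustrative only)
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([0.5, 0.5])

# The compiled function is called like an ordinary Python function
result = foo(a, b)
print(result)

# .code shows the TorchScript code that was compiled from the Python source
print(foo.code)
```

If this runs from a file but fails in a REPL or notebook cell, that difference is worth mentioning alongside the error message.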