Error in @torch.jit.script

I am trying to learn more about the JIT compiler and have been working through the examples in the documentation, in particular this one (from TorchScript — PyTorch 2.1 documentation):

import torch
@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

Running the code resulted in this error:


Do you guys have any ideas on what is wrong?

We made a change that makes bool casting stricter. Try if bool(x.max() > y.max()): instead. We need to update the docs to reflect this; I'll file a GitHub issue.
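For reference, here is a minimal sketch of the documentation example with the explicit bool() cast suggested in the reply. It assumes a PyTorch build where torch.jit.script is available. The type annotations are an addition for clarity (TorchScript already treats unannotated arguments as Tensor), and the sample tensors are hypothetical inputs chosen only to exercise both branches.

```python
import torch


@torch.jit.script
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Explicitly convert the 0-dim comparison result to a bool rather than
    # relying on an implicit tensor-to-bool cast in the `if` condition.
    if bool(x.max() > y.max()):
        r = x
    else:
        r = y
    return r


# Hypothetical sample inputs: a.max() is 5.0 and b.max() is 3.0.
a = torch.tensor([1.0, 5.0])
b = torch.tensor([2.0, 3.0])

print(foo(a, b))  # a has the larger max, so the first branch returns a
print(foo(b, a))  # b's max is smaller, so the else branch returns a
```

The only change from the original example is the bool(...) wrapper in the condition; the control flow and return value are the same.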
