[JIT] [Mobile] Is isinstance() supposed to work with TorchScript?

Hello there,

I’ve run into similar issues while deploying a scripted model on Android.

  1. The first question is “How can I check the whole graph or code of a scripted model?”.
    ScriptedModule.code provides only the top-level code, but the stack trace on mobile shows more information about what the code looks like.
  2. The second is about isinstance(). My model has plenty of checks like isinstance(tensor, torch.Tensor), which are converted to tensor1 = unchecked_cast(Tensor, tensor) in the script code. This is where the error occurs.
  3. Why is there no error while exporting/scripting the module?

isinstance is supported but its result is static. It is useful for Modules that have different attribute types passed in, e.g.

import torch
from typing import List

class M(torch.nn.Module):
    def __init__(self, x):
        super().__init__()
        self.x = x

    def forward(self):
        if isinstance(self.x, List[str]):
            return self.x[0]
        else:
            return self.x + 2

print(torch.jit.script(M(['hi', 'bye'])).graph)
print(torch.jit.script(M(2)).graph)

The compiler sees the isinstance check, evaluates it at compile time, and removes the unused branch. The graphs show this:

graph(%self : ClassType<M>):
  %3 : str = prim::Constant[value="hi"]() # ../test.py:22:19
  return (%3)

graph(%self : ClassType<M>):
  %4 : int = prim::Constant[value=2]() # ../test.py:24:28
  %3 : int = prim::GetAttr[name="x"](%self)
  %5 : int = aten::add(%3, %4) # ../test.py:24:19
  return (%5)

As for 1), we recently changed the behavior so that functions in .code and .graph appear as function calls (previously we were inlining the function bodies). So we’re still missing the functionality to show the called functions. For now you can re-enable inlining to see the entire graph:

import torch

def other_fn(x):
    return x + 10

# Change the inlining mode before you compile
torch._C._jit_set_inline_everything_mode(True)

@torch.jit.script
def fn(x):
    return other_fn(x)

print(fn.code)
print(fn.graph)

You can track this bug in https://github.com/pytorch/pytorch/issues/29750.


You can also print out a model directly with

import torch
import torch.nn as nn

class M(nn.Module):
    def other_fn(self, x):
        return x + 10

    def forward(self, x):
        return self.other_fn(x)

m = torch.jit.script(M())
print(m._c.dump())

This one is unclear to me.

How can we annotate the input of __init__() and the output of the forward() function? The Union[int, List[str]] typing is unsupported.

Thanks a lot! That clarified how to see the whole code.

This returns None, but I haven’t looked for that functionality anyway :stuck_out_tongue:

__init__ on nn.Modules runs in Python (torch.jit.script only sees the module after it has been initialized), so you can annotate that with whatever Python annotations you want (but they won’t be enforced by the compiler at all, for that you should use something like mypy).
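To illustrate the point above, here is a minimal sketch (plain Python, no scripting involved; the class name is made up): the Union annotation on __init__ is perfectly legal because the compiler never sees it, and Python itself does not enforce it either.

```python
from typing import List, Union

class Holder:  # hypothetical stand-in for an nn.Module subclass
    def __init__(self, x: Union[int, List[str]]):
        # Python does not enforce annotations at runtime, and torch.jit.script
        # only ever sees the already-initialized module, so this Union is fine.
        self.x = x

ok = Holder(2)           # matches the annotation
oops = Holder("hello")   # violates it, but runs without complaint
```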

For methods that are compiled (e.g. forward and anything it calls), the return types can be deduced from the code. If you want to explicitly write it out, you can use any of the type annotations listed here.
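For example, an explicit return type on a compiled function is spelled with ordinary typing annotations (a sketch with a made-up helper; only the annotation style is the point):

```python
from typing import List

# Hypothetical helper: the explicit `-> List[int]` return annotation is
# optional here, since the compiler could deduce it from the body.
def repeat_last(xs: List[int], n: int) -> List[int]:
    return xs + [xs[-1]] * n

# Under TorchScript this would be compiled with torch.jit.script(repeat_last).
```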

Unions aren’t supported so they won’t work in TorchScript. As a workaround you could do something like Tuple[Optional[int], Optional[List[str]]].
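A minimal sketch of that workaround (the function name is made up): return a tuple of optionals in which exactly one slot is non-None, and let callers branch on which slot is filled.

```python
from typing import List, Optional, Tuple

# Stand-in for a function that would ideally return Union[int, List[str]]:
# exactly one element of the returned tuple is non-None.
def int_or_strs(use_int: bool) -> Tuple[Optional[int], Optional[List[str]]]:
    if use_int:
        return (2, None)
    return (None, ["hi", "bye"])

i, strs = int_or_strs(True)
# Callers check which slot is filled; in TorchScript, `is not None` checks
# also refine Optional[T] to T inside the guarded branch.
```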