How does the JIT get the function types?

Does it look at the raw source on disk? How can we give it a custom callable type for a function and force it to recompile a copy of the function body with the new types (but a new name)?

Does it look at the raw source on disk?

Yes. The JIT uses the inspect module to retrieve the function's source code, and it determines the function schema (the types of the inputs and outputs) from the type annotations (either Python 3 or mypy style).
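To illustrate the idea (this is not TorchScript's actual code), here is a stdlib-only sketch that reads either annotation style from a function's source text. `schema_from_source` is a hypothetical helper. It assumes, as in TorchScript, that an unannotated argument defaults to Tensor and that an unannotated return type is left to be inferred. For a live function you would pass it `inspect.getsource(fn)`.

```python
import ast
import textwrap


def schema_from_source(src: str) -> str:
    """Build an '(arg types) -> return type' string from one function's source.

    Hypothetical helper that only mimics the idea: it reads a mypy-style
    '# type:' comment if present, otherwise Python 3 annotations.
    """
    tree = ast.parse(textwrap.dedent(src), type_comments=True)
    fn = tree.body[0]
    if not isinstance(fn, ast.FunctionDef):
        raise ValueError("expected a single function definition")
    if fn.type_comment:  # mypy style: '# type: (int, float) -> float'
        return fn.type_comment
    # Python 3 style; unannotated args are treated as Tensor (TorchScript's default)
    args = ", ".join(
        ast.unparse(a.annotation) if a.annotation else "Tensor"
        for a in fn.args.args
    )
    ret = ast.unparse(fn.returns) if fn.returns else "<inferred>"
    return f"({args}) -> {ret}"


print(schema_from_source("def f(x: int, y: float) -> float:\n    return x * y\n"))
print(schema_from_source("def g(x, y):\n    # type: (int, float) -> float\n    return x * y\n"))
```

Both calls describe the same schema. This is why the JIT needs the source on disk (or somewhere inspect can find it): type comments are not visible on the function object at all.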

How can we give it a custom callable type for a function and force it to recompile a copy of the function body with the new types (but a new name)?

Could you provide a code sample to clarify this question? What is it that you want to do?

I’d like to be able to say something like:

import torch.nn as nn


def generic_loop(self, inputs, state):
    # ^^^ I'm trying to reuse this body with different types. ^^^
    outs = []
    for x in inputs:
        out, state = self.fwd(x, state)
        outs.append(out)
    return outs, state


class Cell(nn.Module):
    def __init__(self, forward_sig):
        super().__init__()
        # apply_cell_sig / get_cell_sig: the hypothetical API I'm wishing for
        self.forward = apply_cell_sig(generic_loop, forward_sig)
        # self.forward : (TCustomCell, List[TCustomInput], TCustomState) -> TCustomOut


class CustomCell(Cell):
    def __init__(self):
        super().__init__(
            get_cell_sig(
                CustomCell, TCustomInput, TCustomState, TCustomOut))

    def fwd(self, input: TCustomInput, state: TCustomState):
        return blah(input, state)

More generally, I’m looking for a way to treat a function body as a “template” that I can provide specific types for, and ask the JIT explicitly to give me a new version for it. So it doesn’t just see its full name and go “I’ve already seen this; here’s the cached version”.

Given your confirmation, though, this is probably very difficult. I also couldn’t get the JIT to accept a function that takes an nn.Module instance as an argument (other than through self), so there are multiple issues here.

More generally, I’m looking for a way to treat a function body as a “template” that I can provide specific types for, and ask the JIT explicitly to give me a new version for it. So it doesn’t just see its full name and go “I’ve already seen this; here’s the cached version”.

I don’t think it is a publicly advertised feature, but you could try using @torch.jit._overload. You can find examples of how to use this decorator in test/test_jit.py, such as test_function_overloads. The caveat is that the single body you supply will be used for all overloaded signatures, because there is no concept of TypeVar in the JIT type system. So you might have to use type refinement to write one body that works for all types (see my_conv in test_function_overloading_isinstance, also in test_jit.py).
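Since _overload is private, here is the same "overload stubs plus one body" shape using the standard library's typing.overload as a stand-in. This is a plain-Python analogy, not TorchScript code, and `scale` is a made-up example. The stubs only declare signatures. The single real body has to narrow the argument with isinstance, which is the kind of refinement the isinstance-based test above relies on.

```python
from typing import List, Union, overload


@overload
def scale(x: int) -> int: ...


@overload
def scale(x: List[int]) -> List[int]: ...


def scale(x: Union[int, List[int]]) -> Union[int, List[int]]:
    # Only this body ever runs; it must handle every declared signature,
    # so it refines the type with isinstance before using x.
    if isinstance(x, int):
        return x * 2
    return [v * 2 for v in x]


print(scale(3), scale([1, 2]))
```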

Given your confirmation, though, this is probably very difficult. I also couldn’t get the JIT to accept a function that takes an nn.Module instance as an argument (other than through self), so there are multiple issues here.

Yeah, modules cannot be passed around because there is no Module type in the JIT type system. There is no way to annotate a function correctly to make this work.

This looks promising!

Can you elaborate on what’s happening here?

        # TODO: pyflakes currently does not compose @overload annotation with other
        # decorators. This is fixed on master but not on version 2.1.1.
        # Next version update remove noqa and add @typing.overload annotation

        @torch.jit._overload  # noqa: F811
        def test_simple(x1):  # noqa: F811
            # type: (int) -> int
            pass

        @torch.jit._overload  # noqa: F811
        def test_simple(x1):  # noqa: F811
            # type: (float) -> float
            pass

        def test_simple(x1):  # noqa: F811
            return x1

        def invoke_function():
            return test_simple(1.0), test_simple(.5)

        self.checkScript(invoke_function, ())

        # testing that the functions are cached
        compiled_fns_1 = torch.jit._script._get_overloads(test_simple)
        compiled_fns_2 = torch.jit._script._get_overloads(test_simple)
        # ^^^ HERE HERE HERE ^^^ #
        for a, b in zip(compiled_fns_1, compiled_fns_2):
            self.assertIs(a.graph, b.graph)

Why could the two successive invocations of torch.jit._script._get_overloads(test_simple) get different results? At that point, there’s only one thing attached to the test_simple name.

I was hoping there’d be a way to work around this specifically :frowning: The JIT already works with a self argument of an nn.Module subtype. How can we make this available for free functions? This restriction greatly limits the composability of free functions if I want to use script. Honestly, at that point I’m not sure there’d be a worthwhile speedup either. This is unrelated, but I’ve had some trouble understanding what the JIT does with what it sees, and intuiting where my code might be giving it more than it can handle.

Instead of that, you could also look into:

  • Injecting your code into linecache under different module names, like IPython does. This isn’t exactly documented, but it is clean as far as the JIT is concerned.
  • Getting the AST using get_jit_def, meddling with the AST, and compiling it yourself with _jit_script_compile (you probably need the resolver too), so you basically mimic torch.jit.script minus the caching. This is fairly invasive into JIT internals, but at least it directly serves your purpose.
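A minimal, stdlib-only sketch of the first option. It assumes the JIT finds source through inspect.getsource, as described earlier in the thread. The steps are: render a template with concrete types and a fresh name, register the text in linecache under a unique pseudo-filename, then exec it. `define_from_source`, `TEMPLATE`, `sum_int` and `sum_float` are hypothetical names. I have not checked here what torch.jit.script does with the result. The sketch only shows that each generated copy is a distinct function whose body inspect can recover.

```python
import inspect
import itertools
import linecache
import textwrap
from typing import Any, Callable, Dict, List

_ids = itertools.count()


def define_from_source(src: str, name: str, namespace: Dict[str, Any]) -> Callable:
    """Exec `src` under a unique pseudo-filename registered in linecache,
    so inspect.getsource() can later recover the text (IPython does similar)."""
    filename = f"<generated-{next(_ids)}>"
    src = textwrap.dedent(src)
    if not src.endswith("\n"):
        src += "\n"
    # Entry layout is (size, mtime, lines, fullname); mtime=None tells
    # linecache.checkcache() not to stat (and evict) this non-existent file.
    linecache.cache[filename] = (len(src), None, src.splitlines(keepends=True), filename)
    exec(compile(src, filename, "exec"), namespace)
    return namespace[name]


# A "template" body: the types and the function name are filled in per copy.
TEMPLATE = """
def {name}(xs: List[{t}]) -> {t}:
    total = {zero}
    for x in xs:
        total = total + x
    return total
"""

sum_int = define_from_source(
    TEMPLATE.format(name="sum_int", t="int", zero="0"), "sum_int", {"List": List})
sum_float = define_from_source(
    TEMPLATE.format(name="sum_float", t="float", zero="0.0"), "sum_float", {"List": List})

print(inspect.getsource(sum_float))
```

Because every copy gets its own name and pseudo-file, nothing looks like a function that was already seen. The trade-off is that the "template" is a string, so editors and linters cannot check it.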

Best regards

Thomas

The short answer is not really. The JIT supports classes insofar as they are static. If you JIT-compile Modules, you don’t give it the class source but rather an instance. It then goes through the data members and processes their types, assuming they will stay fixed (which isn’t the case in Python in general).
You might find more success in taking the “JIT” part more literally, e.g. JIT-compiling a local function from a dynamic function.

I’m afraid the JIT isn’t a magic bullet for optimization; it “only” does specific things that are highly desirable.

Can you elaborate on this a bit more?

Absolutely. I meant that it’s a bit unclear to me what those “specific things” are, and explaining them doesn’t seem to be a goal of the documentation.

There are three: the two overloads and the implementation. As you can see in the test, the two calls should not give different results, either in the ordering of the functions or in the graphs produced by compiling them.

Well, one option could be to actually make it a method of self’s class before you JIT that class (or the method).
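A plain-Python sketch of that suggestion, reusing generic_loop from the question. The free function is attached to each class as a method, so the "module" arrives as self instead of as an argument. Accumulate and Concat are hypothetical stand-ins. In PyTorch they would subclass nn.Module, and I am assuming (not verifying here) that torch.jit.script would then compile the attached forward for each class like any method defined in the class body. TorchScript would also still need annotations for non-Tensor arguments.

```python
from typing import Tuple


def generic_loop(self, inputs, state):
    # The shared body from the question; `self` supplies fwd.
    outs = []
    for x in inputs:
        out, state = self.fwd(x, state)
        outs.append(out)
    return outs, state


class Accumulate:
    def fwd(self, x: int, state: int) -> Tuple[int, int]:
        state = state + x
        return state, state


class Concat:
    def fwd(self, x: str, state: str) -> Tuple[str, str]:
        state = state + x
        return state, state


# Make the free function a method of each class.
for cls in (Accumulate, Concat):
    cls.forward = generic_loop

print(Accumulate().forward([1, 2, 3], 0))
print(Concat().forward(["a", "b"], ""))
```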

My chance to sell advanced PyTorch courses. Your chance to be a hero. :slight_smile:

More seriously, we discuss this a bit in section 15.3.1 (Interacting with the PyTorch JIT / What to expect from moving beyond classic Python/PyTorch) of our book, which you can download in exchange for answering a few questions.
To summarize, the main use cases I see are:

  • exporting stuff,
  • getting rid of the GIL for nicer multithreading in deployment,
  • optimizing certain patterns (e.g. pointwise ops in RNNs and elsewhere were one thing we had relatively early), and people are working on expanding this (e.g. to cover reductions).

Best regards

Thomas

There are three bodies in the text, but I don’t see why the two torch.jit._script._get_overloads(test_simple) calls could return different compiled objects even if the cache were disabled.

This bit’s interesting. I might’ve been misunderstanding the cache behavior here. If I script (not trace) different instances of the same class at different times, does it use the cache or treat every instance separately? Because if it treats instances separately, could assigning different modules to attributes be a workaround for JITing functions that use modules passed as arguments (by not passing them through arguments at all)?

Which is why I’d imagine being able to “JIT” the same function with different types. Am I misinterpreting this?

Looks great! Thanks for the rec.

If I script different instances of the same class at different times, does it use the cache or treat every instance separately?

Instances are not cached, but their types in the JIT system are. Every time you script an nn.Module, a type is created for it in the JIT type system. This type is reused if multiple instances of the same module are scripted in the same program, or if the same instance is scripted twice. The ScriptModule returned by torch.jit.script is always fresh and never comes from a cache, but it might refer to a type object that is cached.
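A toy model of that caching behavior in plain Python. It is not TorchScript's implementation. Every call returns a fresh wrapper, while the type object is cached. The cache key here is the class plus the types of the instance's attributes. That is my simplifying assumption for why attribute types matter to the attribute-assignment question above. `toy_script`, `CompiledType`, `Wrapped` and `Holder` are hypothetical names.

```python
from typing import Any, Dict, Tuple

TypeKey = Tuple[type, Tuple[Tuple[str, type], ...]]


class CompiledType:
    """Stands in for a type in the JIT type system (cached)."""
    def __init__(self, key: TypeKey) -> None:
        self.key = key


class Wrapped:
    """Stands in for the ScriptModule returned by torch.jit.script (never cached)."""
    def __init__(self, ctype: CompiledType, obj: Any) -> None:
        self.ctype = ctype
        self.obj = obj


_type_cache: Dict[TypeKey, CompiledType] = {}


def toy_script(obj: Any) -> Wrapped:
    # Assumed key: the class plus each attribute's name and type.
    attrs = tuple(sorted((name, type(value)) for name, value in vars(obj).items()))
    key = (type(obj), attrs)
    if key not in _type_cache:
        _type_cache[key] = CompiledType(key)
    return Wrapped(_type_cache[key], obj)


class Holder:
    def __init__(self, sub: Any) -> None:
        self.sub = sub


a, b, c = toy_script(Holder(1)), toy_script(Holder(2)), toy_script(Holder("x"))
print(a is b, a.ctype is b.ctype, a.ctype is c.ctype)
```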