Should adding @torch.jit.export annotation on a single method make it export all module properties?

I ran into this interesting issue when trying to export a simple nn.Module. Here is the code:

import torch
import torch.nn as nn

class MyMod(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.x = 1
        self.y = 2

    def forward(self, x):
        return 2 * x

    # @torch.jit.export  # <-- UNCOMMENT HERE
    def get_x(self):
        return self.x

mod = MyMod()
traced = torch.jit.trace(mod, [torch.tensor([1.0])])

print(traced.get_x())
print(traced.x)
print(traced.y)


In its current form the code fails on all 3 print statements - AS EXPECTED.

Now when I uncomment the @torch.jit.export annotation, all 3 statements pass.

  • I understand get_x(), because I explicitly exported it.
  • I kind of get x, because it’s a dependency for get_x()
  • I totally don’t get y, because it’s just an independent non-tensor property.

Tested on torch==2.0.1.

Most of the time I wouldn’t really care about some extra properties exported alongside my model.
The problem occurs when I try to use some external libraries like HuggingFace transformers.
For example, when I add ViTForImageClassification as a submodule of MyMod, all its properties
are going to be recursively traced, and this leads to annotation errors like the one below:

Unknown type name 'nn.Module':
  File "blah/site-packages/transformers/", line 1197
    def base_model(self) -> nn.Module:

This is just some innocent annotated property that gets caught in the crossfire.

Any idea if this is expected behavior? Should I report a bug? I already found an ugly workaround, but this feels like something that should have a simpler solution.

Workaround for anyone interested

Just create a new MyExports module and move the exports there. Since the behavior is recursive, it will be limited to the MyExports submodule and leave the ViT model alone. You can then access them like mod.exports.get_x().
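
A minimal sketch of this workaround (module names are my own; the ViT submodule is left out so the snippet stays self-contained):

```python
import torch
import torch.nn as nn

class MyExports(nn.Module):
    """Container for exported methods; scripting stays confined to this submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.x = 1

    def forward(self):
        # nn.Module formally wants a forward; keep it trivial
        return torch.empty(0)

    @torch.jit.export
    def get_x(self):
        return self.x

class MyMod(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.exports = MyExports()
        # self.vit = ViTForImageClassification(...)  # stays untouched by scripting

    def forward(self, x):
        return 2 * x

mod = MyMod()
traced = torch.jit.trace(mod, [torch.tensor([1.0])])
print(traced.exports.get_x())  # 1
```

Only MyExports ends up on the scripting path; the parent MyMod (and any other submodules) are traced as usual.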

I just added print(type(traced)) to the code and saw that:

  • without export decorator it’s <class 'torch.jit._trace.TopLevelTracedModule'>
  • with decorator it’s <class 'torch.jit._script.RecursiveScriptModule'>

So even the result type differs. Is it the case that torch.jit.trace silently enters scripting mode when it sees any export? It’s still weird, because it doesn’t behave exactly like torch.jit.script, but it gets close.
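
The switch is easy to demonstrate in isolation; the two toy classes below are mine, not from the snippet above:

```python
import torch
import torch.nn as nn

class NoExports(nn.Module):
    def forward(self, x):
        return 2 * x

class WithExports(NoExports):
    # the only difference: one exported method
    @torch.jit.export
    def get_one(self):
        return 1

example = [torch.tensor([1.0])]
t1 = torch.jit.trace(NoExports(), example)
t2 = torch.jit.trace(WithExports(), example)
print(type(t1).__name__)  # TopLevelTracedModule
print(type(t2).__name__)  # RecursiveScriptModule
```

A single @torch.jit.export is enough to change the type of the result.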

I also noticed that I can get what I want by calling torch.jit.trace_module explicitly like this:

traced = torch.jit.trace_module(mod, {"forward": [torch.tensor([1.0])], "get_x": []})

Is there a decorator that I could put on get_x to achieve the same result with torch.jit.trace?
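
For now I can roll my own on top of trace_module; a sketch (trace_export and trace_with_exports below are my own hypothetical helpers, not torch APIs, and I return a tensor from get_x here since tracing handles tensor outputs best):

```python
import torch
import torch.nn as nn

def trace_export(*example_inputs):
    # Hypothetical decorator: stash example inputs on the method so it can
    # later be handed to torch.jit.trace_module alongside forward.
    def wrap(fn):
        fn._trace_export_inputs = list(example_inputs)
        return fn
    return wrap

def trace_with_exports(mod, forward_inputs):
    # Collect every decorated method and build the inputs dict for trace_module.
    inputs = {"forward": forward_inputs}
    for name in dir(type(mod)):
        fn = getattr(type(mod), name, None)
        if callable(fn) and hasattr(fn, "_trace_export_inputs"):
            inputs[name] = fn._trace_export_inputs
    return torch.jit.trace_module(mod, inputs)

class MyMod(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.x = 1.0

    def forward(self, x):
        return 2 * x

    @trace_export()  # get_x takes no example inputs
    def get_x(self):
        return torch.tensor(self.x)

mod = MyMod()
traced = trace_with_exports(mod, [torch.tensor([1.0])])
print(traced.get_x())
```

Note that this is still tracing, not scripting, so get_x is recorded with its values baked in as constants.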

I found the piece of code responsible for this behavior:

def make_module(mod, _module_class, _compilation_unit):
    if isinstance(mod, ScriptModule):
        return mod
    elif torch._jit_internal.module_has_exports(mod):
        infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods
        return torch.jit._recursive.create_script_module(
            mod, infer_methods_stubs_fn, share_types=False, is_tracing=True
        )
    else:
        if _module_class is None:
            _module_class = TopLevelTracedModule
        return _module_class(mod, _compilation_unit=_compilation_unit)

This is called when tracing a module and it:

  • checks whether the module contains any export decorators
  • if yes, then calls create_script_module
  • if not, then it does the regular tracing?

I’m not sure what the difference is between regular torch.jit.script and the way it is done here, but there is this is_tracing=True flag, which probably makes the difference.
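
The module_has_exports check in the first branch can be poked at directly (it lives in an internal module, so this may change between versions):

```python
import torch
import torch.nn as nn

class Plain(nn.Module):
    def forward(self, x):
        return x

class Exported(Plain):
    @torch.jit.export
    def extra(self):
        return 1

# This predicate is what flips torch.jit.trace onto the scripting path.
print(torch._jit_internal.module_has_exports(Plain()))     # False
print(torch._jit_internal.module_has_exports(Exported()))  # True
```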

So it seems like this is explicitly coded behavior, but I’m not sure it’s really desired. If someone explicitly uses tracing, there is probably a reason for that. It would probably be safer to throw an error saying that you can’t use exports when tracing. I don’t know.