No graph attribute when using torch.jit._overload_method

PyTorch version: 1.6.0.dev20200407+cu101

When I use torch.jit._overload_method to overload forward, the resulting TorchScript module seems to have no graph attribute.

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: int) -> int:  # noqa: F811
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: F811
        pass

    def forward(self, x):  # noqa: F811
        original = x

        if isinstance(original, int):
            return original + 2
        else:
            return original.sum()

model = MyModule()
s = torch.jit.script(model)
torch._C._jit_pass_inline(s.graph)  # error happens here

error message:

torch.nn.modules.module.ModuleAttributeError: 'RecursiveScriptModule' object has no attribute 'graph'

Hi, this is because there is no single “graph”: there are two separate graphs, one for each overload. I could maybe add an API for something like model.forwards.graphs, or model.forward.graph_for_types(1), to get the graph for a specific overload.

Generally, overloads should be usable and work as you expect when they are contained in other modules, but you may run into some difficulty if you interact with them on a top-level module. They’re still an internal feature, but they will probably be cleaned up and released in the next release or the one after.

Thank you very much:)
I moved the overloads into a submodule and found, as you said, that the resulting graph contains only one of the overloads:

import torch
import torch.nn as nn

class MySubModule(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: int) -> int:  # noqa: F811
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: F811
        pass

    def forward(self, x):  # noqa: F811
        original = x

        if isinstance(original, int):
            return original + 2
        else:
            return original.sum()

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_module = MySubModule()

    def forward(self, x):
        return self.sub_module(x)


model = MyModule()
s = torch.jit.script(model)
torch._C._jit_pass_inline(s.graph)
print(s.graph)

output:

graph(%self : __torch__.MyModule,
      %x.1 : Tensor):
  %2 : __torch__.MySubModule = prim::GetAttr[name="sub_module"](%self)
  %14 : None = prim::Constant()
  %15 : Tensor = aten::sum(%x.1, %14)
  return (%15)

Since the default argument type is torch.Tensor, I only get the graph for the Tensor overload. Is there currently any way to obtain a single TorchScript graph that covers all overloads?

The MyModule you’re using only takes a Tensor, because its unannotated forward argument defaults to Tensor. If you want to access the int graph, you could wrap the submodule in a different module whose forward takes an int.
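A minimal sketch of that suggestion, assuming a recent PyTorch where torch.jit._overload_method still exists (it is internal, so behavior may differ from the 1.6 nightly above). IntWrapper is a hypothetical name. Because the wrapper's forward is annotated as int, the call to the submodule resolves to the int overload, and only that overload is inlined into the wrapper's graph. ScriptModule.inlined_graph is used here in place of calling torch._C._jit_pass_inline by hand.

```python
import torch
import torch.nn as nn


class MySubModule(nn.Module):
    # Overload declarations: only their signatures matter, bodies must be `pass`.
    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: int) -> int:  # noqa: F811
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: F811
        pass

    # Shared implementation; each overload compiles it with a fixed type for x,
    # so the isinstance check is resolved statically per overload.
    def forward(self, x):  # noqa: F811
        if isinstance(x, int):
            return x + 2
        else:
            return x.sum()


class IntWrapper(nn.Module):  # hypothetical wrapper name
    def __init__(self):
        super().__init__()
        self.sub_module = MySubModule()

    # Annotating x as int makes this call site select the int overload.
    def forward(self, x: int) -> int:
        return self.sub_module(x)


s = torch.jit.script(IntWrapper())
print(s.inlined_graph)  # the int overload's body, inlined
```

The printed graph should show an aten::add where the Tensor version above showed aten::sum, and s(3) should return 5.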

Got it, thank you very much:)

Sorry to dredge this one up, but I had exactly the same problem. Could I ask @eellison how this works exactly?

Generally, overloads should be usable and work as you expect when they are contained in other modules

I think I get the idea: once you call forward, TorchScript can inspect the types and choose the right graph? Is there any documentation (or code) I could refer to in order to fully understand how this works? I don’t quite understand why forward is treated differently - when this error occurs I can still call the other functions which have been overloaded.

Hi @Padarn_Wilson, there aren’t good docs for this because the feature was never fully finished. There are still a few rough edges, as you are encountering.

I don’t quite understand why forward is treated differently - when this error occurs I can still call the other functions which have been overloaded.

Hmm, what do you mean exactly? I’m not sure I follow.

Hey @Elias_Ellison, thanks for the clarification.

From the discussion above, I understood that invoking .forward would also infer which overloaded function to call based on the types present, but rereading it I realise I had confused myself. I think it is clear now.
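To make that point concrete: as far as I can tell, TorchScript picks an overload when the *calling* method is compiled, by matching the static type of the argument at each call site against the overload signatures. Nothing is inferred at runtime when forward is invoked. This also means one graph can contain both overloads if a single method calls the submodule with both types. A minimal sketch under those assumptions (BothWrapper is a hypothetical name; torch.jit._overload_method is internal and may change between versions):

```python
from typing import Tuple

import torch
import torch.nn as nn


class MySubModule(nn.Module):
    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: int) -> int:  # noqa: F811
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: F811
        pass

    def forward(self, x):  # noqa: F811
        if isinstance(x, int):
            return x + 2
        else:
            return x.sum()


class BothWrapper(nn.Module):  # hypothetical wrapper name
    def __init__(self):
        super().__init__()
        self.sub_module = MySubModule()

    def forward(self, t: torch.Tensor, n: int) -> Tuple[torch.Tensor, int]:
        # Each call is matched to an overload when this method is compiled:
        # t (Tensor) selects the Tensor overload, n (int) selects the int one.
        return self.sub_module(t), self.sub_module(n)


s = torch.jit.script(BothWrapper())
print(s.inlined_graph)  # contains both aten::sum and aten::add
out_t, out_n = s(torch.ones(3), 3)
```

Here out_t is the sum of the three ones (3.0) and out_n is 5, both produced by one scripted forward.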