Converting a custom nn.Module to TorchScript

I’m trying to convert my model to TorchScript for use in a C++ application.

I have a model file I cannot change, so I’m trying to patch the PyTorch code to be TorchScript-compatible while still using the same model weights.

I have simplified the issue I’m seeing into a small example script here.

Indexing into an nn.ModuleList with a non-literal index requires a module interface type annotation on the left-hand side, so I defined a new interface, MyEncoderModuleInterface. However, once I do this, the conversion fails because the TorchScript compiler cannot find other members, such as empty, on the interface-typed variable submodule.

from torch.nn import ModuleList
import torch


@torch.jit.interface
class MyEncoderModuleInterface(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

class MyEncoderLayer(torch.nn.Module):
    def __init__(self, d=32):
        super().__init__()
        self.d = d
        self.empty = False
        self.layer = torch.nn.Linear(d, d)
    
    def forward(self, x: torch.Tensor):
        if torch.sum(x) > 0:
            return self.layer(x)
        return self.layer(2 * x)

class MyModule(torch.nn.Module):
    def __init__(self, d):
        super().__init__()
        self.d = d
        self.empty = False
        self.layers = torch.nn.ModuleList([MyEncoderLayer(32), MyEncoderLayer(32)])

    def forward(self, x: torch.Tensor, c: int):
        result = torch.rand(self.d)
        for i in range(c):
            submodule: MyEncoderModuleInterface = self.layers[i]
            if not submodule.empty:
                result = submodule.forward(x)
        return result 

d = 32
script = torch.jit.script(MyModule(d))
out = script.forward(torch.randn(d), 2)

The error I see is:

RuntimeError: 
'__torch__.___torch_mangle_165.MyEncoderModuleInterface' object has no attribute or method 'empty'.:
  File "/Users/me/code/torch_script_example/question.py", line 33
        for i in range(c):
            submodule: MyEncoderModuleInterface = self.layers[i]
            if not submodule.empty:
                   ~~~~~~~~~~~~~~~ <--- HERE
                result = submodule.forward(x)
        return result 

I tried adding a dummy attribute to the interface so that MyEncoderModuleInterface has the same members as the layer:

@torch.jit.interface
class MyEncoderModuleInterface(torch.nn.Module):
    def __init__(self) -> None:
        self.empty = False
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

but this leads to a different error:

RuntimeError: 
interfaces declarations should contain 'pass' statement.:
  File "/Users/will/code/torch_script_example/question.py", line 7
    def __init__(self) -> None:
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~
        self.empty = False
        ~~~~~~~~~~~~~~~~~~ <--- HERE

So it doesn’t seem like I can define member variables on an interface like this, which makes sense.

How can I torchscript-ify code like this? What am I missing?
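One workaround that keeps the range(c) loop: an interface can declare methods but not attributes, so empty can be exposed through a method and called through the interface instead. A minimal sketch, assuming you can add a small method to the layer class (in a patched copy or a subclass) and a PyTorch 2.x release where interface-annotated ModuleList indexing is supported; EncoderLayerInterface, EncoderLayer, Model and is_empty are hypothetical names, not part of the original model code.

```python
import torch


# Interface: declares every method TorchScript may call through an
# interface-typed variable. Method bodies must be just `pass`.
@torch.jit.interface
class EncoderLayerInterface(torch.nn.Module):
    def is_empty(self) -> bool:
        pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


class EncoderLayer(torch.nn.Module):
    def __init__(self, d: int = 32):
        super().__init__()
        self.empty = False
        self.layer = torch.nn.Linear(d, d)

    # Hypothetical accessor: exposes the `empty` attribute as a method so it
    # can be part of the interface. @torch.jit.export makes TorchScript
    # compile it even though forward() never calls it.
    @torch.jit.export
    def is_empty(self) -> bool:
        return self.empty

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.sum(x) > 0:
            return self.layer(x)
        return self.layer(2 * x)


class Model(torch.nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.layers = torch.nn.ModuleList([EncoderLayer(d), EncoderLayer(d)])

    def forward(self, x: torch.Tensor, c: int) -> torch.Tensor:
        result = torch.rand(self.d)
        for i in range(c):
            # The non-literal index is accepted because the target variable
            # is annotated with a module interface type.
            submodule: EncoderLayerInterface = self.layers[i]
            if not submodule.is_empty():
                result = submodule.forward(x)
        return result


d = 32
model = Model(d)
scripted = torch.jit.script(model)
x = torch.randn(d)
out = scripted(x, 2)
```

The @torch.jit.export decorator matters here: TorchScript otherwise compiles only forward and what it calls, so the layer would likely not satisfy the interface.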

I cannot reproduce the issue; instead, I get:

TypeError: <class '__main__.MyEncoderModuleInterface'> is a built-in class

when running your code in PyTorch 2.0.0.
After removing the MyEncoderModuleInterface annotation from submodule: MyEncoderModuleInterface = self.layers[i], I get:

RuntimeError: 
Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':
  File "/tmp/ipykernel_7066/2843282298.py", line 23
        result = torch.rand(self.d)
        for i in range(c):
            submodule = self.layers[i]
                        ~~~~~~~~~~~~~~ <--- HERE
            if not submodule.empty:
                result = submodule.forward(x)

Next, I replaced the loop with "for submodule in self.layers:", which seems to work.

Wait, really? When I run the code verbatim from what I posted, I see:

❯ python question.py 
Traceback (most recent call last):
  File "/Users/me/code/torch_script_example/question.py", line 38, in <module>
    script = torch.jit.script(MyModule(d))
  File "/opt/miniconda3/envs/torchscript/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/opt/miniconda3/envs/torchscript/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/opt/miniconda3/envs/torchscript/lib/python3.10/site-packages/torch/jit/_recursive.py", line 546, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/opt/miniconda3/envs/torchscript/lib/python3.10/site-packages/torch/jit/_recursive.py", line 397, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: 
'__torch__.MyEncoderModuleInterface' object has no attribute or method 'empty'.:
  File "/Users/me/code/torch_script_example/question.py", line 33
        for i in range(c):
            submodule: MyEncoderModuleInterface = self.layers[i]
            if not submodule.empty:
                   ~~~~~~~~~~~~~~~ <--- HERE
                result = submodule.forward(x)
        return result 

❯ which python
/opt/miniconda3/envs/torchscript/bin/python

❯ python --version
Python 3.10.10

❯ python -c "import torch; print(torch.__version__)"
2.0.0

❯ uname -a
Darwin macbook-pro-4.lan 22.1.0 Darwin Kernel Version 22.1.0: Sun Oct  9 20:15:09 PDT 2022; root:xnu-8792.41.9~2/RELEASE_ARM64_T6000 arm64

Could something other than the PyTorch version matter? Perhaps the fact that I’m on an M1 chip, on macOS, or using a different Python version…?

I’m just a little worried about why I’d see different outputs with the same input.
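Comparing environments side by side usually narrows this down. A small sketch that gathers the usual suspects; the helper name environment_summary is hypothetical, and python -m torch.utils.collect_env prints a much fuller report. One detail that may be worth including: the traceback in the reply above points at /tmp/ipykernel_7066/2843282298.py, which suggests that run happened in Jupyter, and TorchScript compiles from Python source retrieved via inspect, which may behave differently for classes defined in a notebook.

```python
import platform
import sys

import torch


def environment_summary() -> dict:
    """Collect details that often explain 'same code, different result'."""
    return {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "platform": platform.platform(),
        "machine": platform.machine(),
        # True in an interactive interpreter or a Jupyter kernel, where
        # source lookup for classes can differ from running a .py file.
        "interactive": hasattr(sys, "ps1") or "ipykernel" in sys.modules,
    }


if __name__ == "__main__":
    for key, value in environment_summary().items():
        print(f"{key}: {value}")
```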

Aside from that, your tip does seem to work! I tried this and it does run without error:

    def forward(self, x: torch.Tensor, c: int):
        result = torch.rand(self.d)
        for i, submodule in enumerate(self.layers):
            if i < c and not submodule.empty:
                result = submodule.forward(x)
        return result 

which seems functionally equivalent to the code I originally posted, at least for c <= len(self.layers) 🎉
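A quick way to sanity-check that equivalence, and to produce the file the C++ application needs, is to compare eager and scripted outputs and then save the scripted module. A minimal sketch, assuming the layer definition from the original post and the enumerate-based forward above; the file name model_scripted.pt is hypothetical.

```python
import os
import tempfile

import torch


class MyEncoderLayer(torch.nn.Module):
    def __init__(self, d: int = 32):
        super().__init__()
        self.d = d
        self.empty = False
        self.layer = torch.nn.Linear(d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.sum(x) > 0:
            return self.layer(x)
        return self.layer(2 * x)


class MyModule(torch.nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.empty = False
        self.layers = torch.nn.ModuleList([MyEncoderLayer(d), MyEncoderLayer(d)])

    def forward(self, x: torch.Tensor, c: int) -> torch.Tensor:
        result = torch.rand(self.d)
        # Iterating over a ModuleList is supported; TorchScript unrolls it,
        # so each submodule keeps its concrete type and `.empty` resolves.
        for i, submodule in enumerate(self.layers):
            if i < c and not submodule.empty:
                result = submodule.forward(x)
        return result


d = 32
model = MyModule(d)
scripted = torch.jit.script(model)

x = torch.randn(d)
# With c >= 1 the random initial value is overwritten, so outputs are deterministic.
eager_out = model(x, 2)
script_out = scripted(x, 2)

# Save and reload; the C++ side would load the same file with torch::jit::load.
path = os.path.join(tempfile.mkdtemp(), "model_scripted.pt")
scripted.save(path)
reloaded = torch.jit.load(path)
reloaded_out = reloaded(x, 2)
```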

But I’d still like to know why my interpreter has trouble and yours doesn’t!

It could indeed be related to the Python version, since I’m using 3.8, but I’m not sure.
Good to hear it’s working now!

For posterity, I did try things on 3.8 and got different results. Good to know.

Thanks for the help here.
