I'm trying to convert my model with TorchScript so I can use it in a C++ application.
I have a model file I cannot change, so I'm trying to patch the PyTorch code to be TorchScript-compatible while still using the same model weights.
I have reduced the issue I'm seeing to the small example script below.
Indexing into an `nn.ModuleList` with a loop variable requires a type hint on the left-hand side, so I defined a new interface, `MyEncoderModuleInterface`. However, once I do this, the conversion falls apart because TorchScript cannot find other members, such as `empty`, on the `submodule` variable.
```python
from torch.nn import ModuleList
import torch


@torch.jit.interface
class MyEncoderModuleInterface(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


class MyEncoderLayer(torch.nn.Module):
    def __init__(self, d=32):
        super().__init__()
        self.d = d
        self.empty = False
        self.layer = torch.nn.Linear(d, d)

    def forward(self, x: torch.Tensor):
        if torch.sum(x) > 0:
            return self.layer(x)
        return self.layer(2 * x)


class MyModule(torch.nn.Module):
    def __init__(self, d):
        super().__init__()
        self.d = d
        self.empty = False
        self.layers = torch.nn.ModuleList([MyEncoderLayer(32), MyEncoderLayer(32)])

    def forward(self, x: torch.Tensor, c: int):
        result = torch.rand(self.d)
        for i in range(c):
            submodule: MyEncoderModuleInterface = self.layers[i]
            if not submodule.empty:
                result = submodule.forward(x)
        return result


d = 32
script = torch.jit.script(MyModule(d))
out = script.forward(torch.randn(d), 2)
```
The error I see is:
```
RuntimeError:
'__torch__.___torch_mangle_165.MyEncoderModuleInterface' object has no attribute or method 'empty'.:
  File "/Users/me/code/torch_script_example/question.py", line 33
        for i in range(c):
            submodule: MyEncoderModuleInterface = self.layers[i]
            if not submodule.empty:
                   ~~~~~~~~~~~~~~~ <--- HERE
                result = submodule.forward(x)
        return result
```
I have tried defining a dummy attribute on the interface, like this, so that `MyEncoderModuleInterface` looks the same as the layers:
```python
@torch.jit.interface
class MyEncoderModuleInterface(torch.nn.Module):
    def __init__(self) -> None:
        self.empty = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass
```
but this leads to a different error:
```
RuntimeError:
interfaces declarations should contain 'pass' statement.:
  File "/Users/will/code/torch_script_example/question.py", line 7
    def __init__(self) -> None:
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~
        self.empty = False
        ~~~~~~~~~~~~~~~~~~ <--- HERE
```
So it seems I can't define member variables on an interface this way, which makes sense.
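Interfaces cannot hold attributes, but they can declare additional methods besides `forward`. One possible workaround, sketched here under the assumption that patching the layer's Python class (not its weights) is acceptable, is to expose `empty` through a method on the interface. The `is_empty` name is hypothetical; `@torch.jit.export` is used so that TorchScript compiles the method even though `forward` never calls it, which the layer needs in order to satisfy the interface. Since only a method is added, the parameters and `state_dict` keys should be unchanged, so the existing weights ought to load as before.

```python
import torch


@torch.jit.interface
class MyEncoderModuleInterface(torch.nn.Module):
    # Interfaces carry no state, but they may declare several methods;
    # every body must be just `pass`.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

    def is_empty(self) -> bool:
        pass


class MyEncoderLayer(torch.nn.Module):
    def __init__(self, d: int = 32):
        super().__init__()
        self.d = d
        self.empty = False
        self.layer = torch.nn.Linear(d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.sum(x) > 0:
            return self.layer(x)
        return self.layer(2 * x)

    # Hypothetical accessor: @torch.jit.export forces TorchScript to compile
    # it even though forward() never calls it, so the interface is satisfied.
    @torch.jit.export
    def is_empty(self) -> bool:
        return self.empty


class MyModule(torch.nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.empty = False
        self.layers = torch.nn.ModuleList([MyEncoderLayer(d), MyEncoderLayer(d)])

    def forward(self, x: torch.Tensor, c: int) -> torch.Tensor:
        result = torch.rand(self.d)
        for i in range(c):
            submodule: MyEncoderModuleInterface = self.layers[i]
            # Query the flag through the interface method, not the attribute.
            if not submodule.is_empty():
                result = submodule.forward(x)
        return result


d = 32
model = MyModule(d)
script = torch.jit.script(model)
x = torch.randn(d)
out = script.forward(x, 2)
```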
How can I make code like this TorchScript-compatible? What am I missing?
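A different option that may avoid the interface altogether, again only a sketch: TorchScript can iterate over an `nn.ModuleList` directly (including through `enumerate`) by unrolling the loop, and inside the unrolled body each element keeps its concrete type, so `layer.empty` should resolve without any type hint. The runtime bound `c` then becomes an ordinary comparison. One behavioral difference to be aware of: when `c` exceeds the number of layers, this version simply stops after the last layer instead of raising an index error as the `range(c)` version would.

```python
import torch


class MyEncoderLayer(torch.nn.Module):
    def __init__(self, d: int = 32):
        super().__init__()
        self.d = d
        self.empty = False
        self.layer = torch.nn.Linear(d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.sum(x) > 0:
            return self.layer(x)
        return self.layer(2 * x)


class MyModule(torch.nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.empty = False
        self.layers = torch.nn.ModuleList([MyEncoderLayer(d), MyEncoderLayer(d)])

    def forward(self, x: torch.Tensor, c: int) -> torch.Tensor:
        result = torch.rand(self.d)
        # Iterating the ModuleList itself (instead of indexing it with a runtime
        # int) lets TorchScript unroll the loop, so each `layer` has its concrete
        # type and `layer.empty` is a plain bool attribute lookup.
        for i, layer in enumerate(self.layers):
            # The runtime bound `c` is checked as a normal comparison.
            if i < c and not layer.empty:
                result = layer(x)
        return result


d = 32
model = MyModule(d)
script = torch.jit.script(model)
x = torch.randn(d)
out = script.forward(x, 2)
```

Unrolling is cheap here because the list holds only two layers; for a very long list it would enlarge the compiled graph, which is where the interface-based approach might be preferable.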