Hi,
I am trying to use libtorch. According to the docs, torch.jit.trace
cannot handle models with control flow, so I intend to use torch.jit.ScriptModule
instead. This is my code:
class Resnet18(nn.Module):
# Network definition; the constructor and forward bodies are elided ("...") in this post.
# NOTE(review): __init__ is shown taking only self, yet SerializeModule below
# constructs it with n_classes=10, pre_act=True -- presumably the real
# signature was elided along with the body; confirm.
def __init__(self):
...
# Forward pass taking a single argument `inten`; implementation not shown.
def forward(self, inten):
...
class SerializeModule(torch.jit.ScriptModule):
# ScriptModule wrapper that builds a Resnet18, restores its weights from a
# checkpoint file, and exposes a scripted forward that delegates to it.
def __init__(self):
super(SerializeModule, self).__init__()
# Checkpoint path is hard-coded; torch.load below reads it at construction time.
save_pth = 'res/model_final_naive.pth'
# NOTE(review): Resnet18.__init__ as shown in this post accepts no keyword
# arguments (its body is elided) -- confirm the real signature.
self.model = Resnet18(n_classes=10, pre_act=True)
state_dict = torch.load(save_pth)
self.model.load_state_dict(state_dict)
# Switch the wrapped model to evaluation mode before scripting/saving.
self.model.eval()
@torch.jit.script_method
def forward(self, inten):
# NOTE(review): self.model is an instance of Resnet18, which is declared as
# a plain nn.Module rather than a ScriptModule; the RuntimeError reported
# below ("attribute lookup is not defined on python value of type
# 'Resnet18'") points at this call.
out = self.model.forward(inten)
return out
def main():
# Construct the scripted wrapper (its constructor loads the checkpoint) and
# call save() with the output path 'res/model_serial.pt'.
serial = SerializeModule()
serial.save('res/model_serial.pt')
# Run main() only when this file is executed as a script.
if __name__ == "__main__":
main()
This gives the following error message:
File "serialize.py", line 24, in <module>
main()
File "serialize.py", line 20, in main
serial = SerializeModule()
File "/home/coin/build/miniconda3/envs/py36/lib/python3.6/site-packages/torch/jit/__init__.py", line 1047, in init_then_register
_create_methods_from_stubs(self, methods)
File "/home/coin/build/miniconda3/envs/py36/lib/python3.6/site-packages/torch/jit/__init__.py", line 1012, in _create_methods_from_stubs
self._c._create_methods(self, defs, rcbs, defaults)
RuntimeError:
attribute lookup is not defined on python value of type 'Resnet18':
@torch.jit.script_method
def forward(self, inten):
out = self.model.forward(inten)
~~~~~~~~~~~~~~~~~~ <--- HERE
return out
What is the correct way to use TorchScript, please? And what causes this error?