Hi, I ran the following code:
```python
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

my_script_module = MyModule(10, 20)
my_script_module.save("mymodule.pt")
```
However, I get this error:
```
Traceback (most recent call last):
  File "torchscript_annotation_convert.py", line 16, in <module>
    my_script_module = MyModule(10, 20)
  File "/home/ssm/.conda/envs/ocrcpp/lib/python3.6/site-packages/torch/jit/__init__.py", line 907, in init_then_register
    _create_methods_from_stubs(self, methods)
  File "/home/ssm/.conda/envs/ocrcpp/lib/python3.6/site-packages/torch/jit/__init__.py", line 868, in _create_methods_from_stubs
    self._create_methods(defs, rcbs, defaults)
RuntimeError:
expected a boolean expression for condition but found Tensor, to use a tensor in a boolean expression, explicitly cast it with `bool()`:
@torch.jit.script_method
def forward(self, input):
    if input.sum() > 0:
       ~~~~~~~~~~~~~~~ <--- HERE
        output = self.weight.mv(input)
    else:
        output = self.weight + input
    return output
```
I guess `input.sum() > 0` should evaluate to a `bool` here, not a `Tensor`.
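The error message itself hints at a workaround: wrap the condition in `bool()`. Below is a minimal sketch of the same module with only that change, assuming the legacy `torch.jit.ScriptModule` / `@torch.jit.script_method` API used above is available in your PyTorch build (not verified against 1.0.0.dev20181119 specifically).

```python
import torch


class MyModule(torch.jit.ScriptModule):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        # Explicitly convert the 0-dim comparison result to a bool,
        # as the TorchScript error message suggests.
        if bool(input.sum() > 0):
            output = self.weight.mv(input)  # (N, M) @ (M,) -> (N,)
        else:
            output = self.weight + input  # (M,) broadcasts to (N, M)
        return output


my_script_module = MyModule(10, 20)
my_script_module.save("mymodule.pt")
```

The only difference from the original is the `bool(...)` cast in the `if` condition; everything else is left as-is.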
Here is my environment:
```
Package    Version
---------- -----------------
certifi    2018.10.15
cffi       1.11.5
mkl-fft    1.0.6
mkl-random 1.0.1
numpy      1.15.4
pip        18.1
pycparser  2.19
setuptools 40.6.2
torch      1.0.0.dev20181119
wheel      0.32.3
```
Hope somebody can help me out.