Converting to Torch Script via Annotation

Hi, I ran the following code:

import torch

class MyModule(torch.jit.ScriptModule):
        def __init__(self, N, M):
                super(MyModule, self).__init__()
                self.weight = torch.nn.Parameter(torch.rand(N, M))

        @torch.jit.script_method
        def forward(self, input):
                if input.sum() > 0:
                        output = self.weight.mv(input)
                else:
                        output = self.weight + input
                return output

my_script_module = MyModule(10, 20)

my_script_module.save("mymodule.pt")

However, I get an error:

Traceback (most recent call last):
  File "torchscript_annotation_convert.py", line 16, in <module>
    my_script_module = MyModule(10, 20)
  File "/home/ssm/.conda/envs/ocrcpp/lib/python3.6/site-packages/torch/jit/__init__.py", line 907, in init_then_register
    _create_methods_from_stubs(self, methods)
  File "/home/ssm/.conda/envs/ocrcpp/lib/python3.6/site-packages/torch/jit/__init__.py", line 868, in _create_methods_from_stubs
    self._create_methods(defs, rcbs, defaults)
RuntimeError: 
expected a boolean expression for condition but found Tensor, to use a tensor in a boolean expression, explicitly cast it with `bool()`:
@torch.jit.script_method
def forward(self, input):
	if input.sum() > 0:
     ~~~~~~~~~~~~~~~ <--- HERE
		output = self.weight.mv(input)
	else:
		output = self.weight + input
	return output

My guess is that input.sum() > 0 produces a Tensor rather than a Python bool.
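A quick eager-mode check supports that guess: comparing a tensor with a number returns a zero-dimensional Tensor, and bool() is what turns it into a Python bool. The vector length of 5 is an arbitrary choice for illustration, and the values are random.

```python
import torch

x = torch.rand(5)

cond = x.sum() > 0      # comparison result is a 0-dim Tensor, not a Python bool
print(type(cond))       # <class 'torch.Tensor'>
print(cond.dim())       # 0

flag = bool(cond)       # explicit conversion to a Python bool
print(type(flag))       # <class 'bool'>

# bool() only accepts tensors with exactly one element
try:
    bool(x > 0)
except RuntimeError as e:
    print("RuntimeError:", e)
```

The last part shows why the cast has to be explicit: a comparison on a multi-element tensor has no single truth value, so it is the one-element result of sum() that makes bool() valid here.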

Here is my environment:

Package    Version          
---------- -----------------
certifi    2018.10.15       
cffi       1.11.5           
mkl-fft    1.0.6            
mkl-random 1.0.1            
numpy      1.15.4           
pip        18.1             
pycparser  2.19             
setuptools 40.6.2           
torch      1.0.0.dev20181119
wheel      0.32.3

I hope somebody can help me out.

Also, I am using Python 3.6 on Ubuntu 14.04.

As the error message suggests, you just need to cast it to a proper boolean explicitly:
if bool(input.sum() > 0): should work, no?
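For reference, here is a minimal sketch of the same fix applied to a free function instead of a ScriptModule method, compiled with torch.jit.script. The function name pick_branch and the 3x3 sizes are made up for illustration; only the condition and the two branches come from the code above.

```python
import torch


# Same branching logic as MyModule.forward, as a standalone scripted function.
# The explicit bool() turns the one-element comparison result into the bool
# that a TorchScript "if" expects.
@torch.jit.script
def pick_branch(weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    if bool(x.sum() > 0):
        output = weight.mv(x)      # matrix-vector product
    else:
        output = weight + x        # broadcast add
    return output


w = torch.rand(3, 3)
print(pick_branch(w, torch.ones(3)).shape)   # torch.Size([3])
print(pick_branch(w, -torch.ones(3)).shape)  # torch.Size([3, 3])
```

Note that the two branches return tensors of different shapes, which TorchScript accepts because both are typed simply as Tensor.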

Yes, it works. Thank you, @albanD!
By the way, the code in the PyTorch tutorial seems to be wrong. It currently reads:

import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_script_module = MyModule()

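It should be replaced with the following code, which adds the explicit bool() cast and passes the required N and M arguments to the constructor: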

import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        # TorchScript needs a real bool here, so cast the one-element Tensor
        if bool(input.sum() > 0):
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

# N and M are required; with a 10x10 weight, forward expects a 1-D input of length 10
my_script_module = MyModule(10, 10)
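In newer PyTorch releases (this sketch assumes 1.2 or later, where torch.jit.script accepts an nn.Module instance), the same model can also be written as a plain nn.Module and scripted afterwards. The parameter name x and the temporary file location are illustrative choices, not part of the tutorial.

```python
import os
import tempfile

import torch


class MyModule(torch.nn.Module):
    # Assumes PyTorch >= 1.2: a plain nn.Module, compiled with torch.jit.script
    def __init__(self, N: int, M: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Explicit bool() keeps the condition a plain bool for TorchScript
        if bool(x.sum() > 0):
            output = self.weight.mv(x)
        else:
            output = self.weight + x
        return output


scripted = torch.jit.script(MyModule(10, 10))

pos = scripted(torch.ones(10))    # sum > 0: matrix-vector product, shape (10,)
neg = scripted(-torch.ones(10))   # sum <= 0: broadcast add, shape (10, 10)
print(pos.shape, neg.shape)

# Save and reload, mirroring save("mymodule.pt") in the question
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "mymodule.pt")
    scripted.save(path)
    loaded = torch.jit.load(path)

print(torch.equal(loaded(torch.ones(10)), pos))
```

The reloaded module carries the same weight, so it should reproduce the output of the scripted module for the same input.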