class MultiOp(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
multiNum = 2
ctx.save_for_backward(multiNum)
result = multiNum * input
return result
@staticmethod
def backward(ctx, grad_output):
multiNum = ctx.saved_variables
result = grad_output / multiNum
return result
multiOp = MultiOp.apply
class Net(torch.jit.ScriptModule):
def __init__(self):
super(Net, self).__init__()
@torch.jit.script_method
def forward(self, x):
x = multiOp(x)
return x
scriptNet = Net()
torch.jit.save(scriptNet, "sample.pt")
RuntimeError Traceback (most recent call last)
<ipython-input-55-40f7e60ca520> in <module>
----> 1 torch.jit.save(scriptNet, "sample.pt")
~/miniconda3/envs/dingyongchao/lib/python3.7/site-packages/torch/jit/__init__.py in save(m, f, _extra_files)
196 (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
197 (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
--> 198 m.save(f, _extra_files=_extra_files)
199 else:
200 ret = m.save_to_buffer(_extra_files=_extra_files)
~/miniconda3/envs/dingyongchao/lib/python3.7/site-packages/torch/jit/__init__.py in save(self, *args, **kwargs)
1203
1204 def save(self, *args, **kwargs):
-> 1205 return self._c.save(*args, **kwargs)
1206
1207 def save_to_buffer(self, *args, **kwargs):
RuntimeError:
could not export python function call MultiOp. Remove calls to Python functions before export. Did you forget add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__.:
@torch.jit.script_method
def forward(self, x):
x = multiOp(x)
~~~~~~~ <--- HERE
return x
So I want to how to define a Function ?