I tried to compile code that uses a custom op defined with `torch.autograd.Function`:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


class mul2(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, dx):
        return dx * 2


def f(a, b):
    c = a + b
    d = mul2.apply(c)
    e = torch.tanh(d * c)
    return d + (e + e)


print(torch.jit.script(f).code)
```
and got the following error:
```
Traceback (most recent call last):
  File "revisble.py", line 21, in <module>
    print(torch.jit.script(f).code)
  File "/Users/***/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1226, in script
    fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
  File "/Users/***/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1075, in _compile_and_register_class
    ast = get_jit_class_def(obj, obj.__name__)
  File "/Users/***/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 148, in get_jit_class_def
    self_name=self_name) for method in methods]
  File "/Users/***/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 148, in <listcomp>
    self_name=self_name) for method in methods]
  File "/Users/***/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 169, in get_jit_def
    return build_def(ctx, py_ast.body[0], type_line, self_name)
  File "/Users/***/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 198, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name)
  File "/Users/***/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 224, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
at /Users/***/anaconda3/lib/python3.7/site-packages/torch/autograd/function.py:26:25
    def mark_dirty(self, *args):
                         ~~~~~ <--- HERE
        r"""Marks given tensors as modified in an in-place operation.
        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be inputs.**
        Every tensor that's been modified in-place in a call to :func:`forward`
        should be given to this function, to ensure correctness of our checks.
        It doesn't matter whether the function is called before or after
        modification.
'mul2' is being compiled since it was called from 'f'
at revisble.py:17:4
def f(a, b):
    c = a + b
    d = mul2.apply(c)
    ~~~~~~~~~~~~~~~~ <--- HERE
    e = torch.tanh(d * c)
    return d + (e + e)
```
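Reading the traceback, scripting `f` makes TorchScript try to compile `mul2` itself as a TorchScript class ("'mul2' is being compiled since it was called from 'f'"), and it fails on `mark_dirty(self, *args)`, a method inherited from `Function`, because compiled functions cannot take `*args`. The op itself is not the problem. A minimal sketch to confirm that `mul2` is correct in eager mode, using `torch.autograd.gradcheck` to compare the hand-written `backward` against numerical gradients (double-precision input is used because finite differences are unreliable in float32):

```python
import torch
from torch.autograd import Function, gradcheck


class mul2(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, dx):
        # d(2x)/dx = 2, so the incoming gradient is scaled by 2
        return dx * 2


# gradcheck returns True on success and raises if the gradients disagree
x = torch.randn(4, dtype=torch.double, requires_grad=True)
ok = gradcheck(mul2.apply, (x,))
print(ok)  # True
```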
`torch.jit.trace` does not work either. Does script mode support custom ops, and if so, what is the correct way to handle a custom op?
Thanks!
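One possible workaround, sketched under the assumption that the installed PyTorch version provides `torch.jit.ignore` and accepts type annotations on scripted functions: wrap the `apply` call in a plain Python function marked `@torch.jit.ignore`, so the compiler leaves that call to the Python interpreter instead of trying to compile `mul2` as a TorchScript class. The wrapper name `mul2_op` is hypothetical. Because the call still runs in eager Python, autograd records it and uses `mul2.backward`.

```python
import torch
from torch.autograd import Function


class mul2(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, dx):
        return dx * 2


# Hypothetical wrapper: @torch.jit.ignore tells the compiler not to compile
# this function; calls to it from scripted code go back to Python.
@torch.jit.ignore
def mul2_op(x: torch.Tensor) -> torch.Tensor:
    return mul2.apply(x)


@torch.jit.script
def f(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    c = a + b
    d = mul2_op(c)  # executed in Python, so mul2's custom backward is used
    e = torch.tanh(d * c)
    return d + (e + e)


# Usage: forward and backward both work through the scripted function
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
f(a, b).sum().backward()
print(a.grad)
```

The trade-off is that this does not produce a Python-free artifact: the `torch.jit.ignore` documentation states that models with ignored functions cannot be exported. For a fully exportable model, the op would instead have to be implemented as a C++ custom operator.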