Hi all, I am trying to define a custom @ operator (__matmul__) on a class and then use it in a TorchScript model, but scripting fails. However, when I test the same approach with the + operator (__add__), it works. Here is a minimal demo:
Source code:
import torch
import traceback
class NewTensor:
    def __init__(self, value: torch.Tensor):
        self.value = value

    def __matmul__(self, other: "NewTensor") -> torch.Tensor:
        return -torch.matmul(self.value, other.value)

    def __add__(self, other: "NewTensor") -> torch.Tensor:
        return -torch.add(self.value, other.value)
class mul_model:
    def __init__(self) -> None:
        pass

    def forward(self, x: NewTensor, y: NewTensor):
        return x @ y

class add_model:
    def __init__(self) -> None:
        pass

    def forward(self, x: NewTensor, y: NewTensor):
        return x + y
model_add = add_model()
try:
    a_s = torch.jit.script(model_add)
    print("add model success")
except Exception:
    print("add model fail")
    traceback.print_exc()

model_mul = mul_model()
try:
    a_s = torch.jit.script(model_mul)
    print("mul model success")
except Exception:
    print("mul model fail")
    traceback.print_exc()
Output:
add model success
mul model fail
Traceback (most recent call last):
File "test.py", line 42, in <module>
a_s = torch.jit.script(model_mul)
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_script.py", line 1351, in script
return torch.jit._recursive.create_script_class(obj)
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 428, in create_script_class
_compile_and_register_class(type(obj), rcb, qualified_class_name)
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 46, in _compile_and_register_class
script_class = torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb)
RuntimeError:
Arguments for call are not valid.
The following variants are available:
aten::matmul(Tensor self, Tensor other) -> Tensor:
Expected a value of type 'Tensor' for argument 'self' but instead found type '__torch__.NewTensor (of Python compilation unit at: 0x55951834bdb0)'.
aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!):
Expected a value of type 'Tensor' for argument 'self' but instead found type '__torch__.NewTensor (of Python compilation unit at: 0x55951834bdb0)'.
The original call is:
File "test.py", line 21
def forward(self, x: NewTensor, y: NewTensor):
return x @ y
~~~~~ <--- HERE
I’m unsure if TorchScript is supposed to support custom classes, so could you check if creating a custom Tensor class via __torch_function__ would work instead as described here?
class mul_model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: NewTensor, y: NewTensor):
        return x @ y

class add_model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: NewTensor, y: NewTensor):
        return x + y
add model success
Traceback (most recent call last):
File "test.py", line 97, in <module>
a_s = torch.jit.script(model_mul)
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_script.py", line 1286, in script
return torch.jit._recursive.create_script_module(
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 460, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 526, in create_script_module_impl
create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 377, in create_methods_and_properties_from_stubs
concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError:
Arguments for call are not valid.
The following variants are available:
aten::matmul(Tensor self, Tensor other) -> Tensor:
Expected a value of type 'Tensor' for argument 'self' but instead found type '__torch__.NewTensor (of Python compilation unit at: 0x55af73107d90)'.
aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!):
Expected a value of type 'Tensor' for argument 'self' but instead found type '__torch__.NewTensor (of Python compilation unit at: 0x55af73107d90)'.
The original call is:
File "test.py", line 21
def forward(self, x: NewTensor, y: NewTensor):
return x @ y
~~~~~ <--- HERE
It is odd that TorchScript fails on the matmul operator but works on the add operator, even after making both models subclasses of torch.nn.Module.
I tried __torch_function__ but it cannot be scripted by TorchScript because TorchScript analyzes the code statically.
Basically, I want to define a custom class NewTensor that can be scripted by TorchScript, following the practice for TorchScript custom classes. NewTensor defines both __add__ and __matmul__ to support the + and @ operators. However, TorchScript scripts + without problems but fails on @. I am wondering whether this is a bug in TorchScript.
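In the code above, if I use x + y where x and y are NewTensor, TorchScript resolves the operator to NewTensor.__add__(x, y). But if I use x @ y, TorchScript does not call NewTensor.__matmul__; it tries to call aten::matmul directly on the NewTensor operands and raises the error shown above.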
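In case it helps anyone hitting the same error, here is a minimal workaround sketch. It assumes (based only on the traceback above, not on a confirmed reading of the TorchScript source) that the installed version does not map @ to a class's __matmul__, so it routes the operation through an explicitly named scripted function instead of the operator. The names neg_matmul and run are hypothetical, and the class is decorated with @torch.jit.script, which differs slightly from the original demo where the class is compiled implicitly.

```python
import torch


@torch.jit.script
class NewTensor:
    # Thin wrapper around a tensor, as in the original demo.
    def __init__(self, value: torch.Tensor):
        self.value = value


@torch.jit.script
def neg_matmul(x: NewTensor, y: NewTensor) -> torch.Tensor:
    # Hypothetical helper: same logic as NewTensor.__matmul__ in the demo,
    # but called by name so no operator lookup is involved.
    return -torch.matmul(x.value, y.value)


@torch.jit.script
def run(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Wrap the raw tensors inside TorchScript and call the helper explicitly
    # instead of writing `NewTensor(a) @ NewTensor(b)`.
    return neg_matmul(NewTensor(a), NewTensor(b))


a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.eye(2)
result = run(a, b)
print(result)
```

This loses the @ syntax, so it is only a stopgap until it is clear whether the missing __matmul__ dispatch is intended behavior or a bug.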