TorchScript fails with the @ (__matmul__) operator

Hi all, I am trying to define a custom @ operator (via __matmul__) in a class and then use it in a TorchScript model, but scripting fails. However, when I test the + operator (via __add__), it works. Here is a demo:

Source code:

import torch
import traceback

class NewTensor:
    def __init__(self, value: torch.Tensor):
        self.value = value  

    def __matmul__(self, other: "NewTensor") -> torch.Tensor:
        return -torch.matmul(self.value, other.value)

    def __add__(self, other: "NewTensor") -> torch.Tensor:
        return -torch.add(self.value, other.value)


class mul_model:
    def __init__(self) -> None:
        pass

    def forward(self, x: NewTensor, y: NewTensor):
        return x @ y


class add_model:
    def __init__(self) -> None:
        pass

    def forward(self, x: NewTensor, y: NewTensor):
        return x + y


model_add = add_model()
try:
    a_s = torch.jit.script(model_add)
    print("add model success")
except Exception as e:
    print("add model fail")
    traceback.print_exc()

model_mul = mul_model()
try:
    a_s = torch.jit.script(model_mul)
    print("mul model success")
except Exception:
    print("mul model fail")
    traceback.print_exc()

Output:

add model success
mul model fail
Traceback (most recent call last):
  File "test.py", line 42, in <module>
    a_s = torch.jit.script(model_mul)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_script.py", line 1351, in script
    return torch.jit._recursive.create_script_class(obj)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 428, in create_script_class
    _compile_and_register_class(type(obj), rcb, qualified_class_name)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 46, in _compile_and_register_class
    script_class = torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb)
RuntimeError:
Arguments for call are not valid.
The following variants are available:

  aten::matmul(Tensor self, Tensor other) -> Tensor:
  Expected a value of type 'Tensor' for argument 'self' but instead found type '__torch__.NewTensor (of Python compilation unit at: 0x55951834bdb0)'.

  aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'Tensor' for argument 'self' but instead found type '__torch__.NewTensor (of Python compilation unit at: 0x55951834bdb0)'.

The original call is:
  File "test.py", line 21
    def forward(self, x: NewTensor, y: NewTensor):
        return x @ y
               ~~~~~ <--- HERE

Has anyone come across a similar issue?

Thanks
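For comparison, a minimal eager-mode check (no TorchScript involved) shows that the class itself behaves as intended: plain Python dispatches x @ y to NewTensor.__matmul__ and x + y to NewTensor.__add__. The input tensors are arbitrary illustration values, and this sketch assumes only that torch is installed.

```python
import torch


class NewTensor:
    def __init__(self, value: torch.Tensor):
        self.value = value

    def __matmul__(self, other: "NewTensor") -> torch.Tensor:
        return -torch.matmul(self.value, other.value)

    def __add__(self, other: "NewTensor") -> torch.Tensor:
        return -torch.add(self.value, other.value)


x = NewTensor(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
y = NewTensor(torch.eye(2))

# Eager Python resolves both operators through the class's magic methods;
# only the TorchScript compiler treats @ differently.
matmul_result = x @ y  # calls NewTensor.__matmul__
add_result = x + y     # calls NewTensor.__add__
```

Since the same class works in eager mode, the failure appears to be specific to how the TorchScript compiler handles @.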

I’m unsure if TorchScript is supposed to support custom classes, so could you check if creating a custom Tensor class via __torch_function__ would work instead as described here?
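A minimal eager-mode sketch of the __torch_function__ approach, assuming the goal is a Tensor subclass whose @ returns the negated product (as NewTensor.__matmul__ does). The class name NegMatmulTensor is hypothetical, and the sketch only covers eager execution, not scripting.

```python
import torch


class NegMatmulTensor(torch.Tensor):
    """Hypothetical Tensor subclass: matmul / @ returns the negated product."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Run the default implementation; it re-wraps tensor results as cls.
        result = super().__torch_function__(func, types, args, kwargs)
        # Depending on the PyTorch version, `a @ b` may arrive here as
        # Tensor.__matmul__ or Tensor.matmul, so check every matmul entry point.
        if func in (torch.matmul, torch.Tensor.matmul, torch.Tensor.__matmul__):
            return -result
        return result


a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]).as_subclass(NegMatmulTensor)
b = torch.eye(2).as_subclass(NegMatmulTensor)

out = a @ b  # routed through NegMatmulTensor.__torch_function__
```

Negating the result inside the hook calls __torch_function__ again (for the negation op), but that call does not match the matmul check, so the product is negated exactly once.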

Thanks, I will try it.

I tried something like the code below, but it still fails.

class mul_model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: NewTensor, y: NewTensor):
        return x @ y


class add_model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: NewTensor, y: NewTensor):
        return x + y

Output:

add model success
Traceback (most recent call last):
  File "test.py", line 97, in <module>
    a_s = torch.jit.script(model_mul)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_script.py", line 1286, in script
    return torch.jit._recursive.create_script_module(
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 460, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 526, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 377, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError:
Arguments for call are not valid.
The following variants are available:

  aten::matmul(Tensor self, Tensor other) -> Tensor:
  Expected a value of type 'Tensor' for argument 'self' but instead found type '__torch__.NewTensor (of Python compilation unit at: 0x55af73107d90)'.

  aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'Tensor' for argument 'self' but instead found type '__torch__.NewTensor (of Python compilation unit at: 0x55af73107d90)'.

The original call is:
  File "test.py", line 21
    def forward(self, x: NewTensor, y: NewTensor):
        return x @ y
               ~~~~~ <--- HERE

It is strange that TorchScript fails on @ (matmul) but works on the + (add) operator.

Did you try to add this operation via __torch_function__?

Hi @ptrblck,

I tried __torch_function__ but it cannot be scripted by TorchScript because TorchScript analyzes the code statically.

Basically, I want to define a custom class NewTensor that can be scripted by TorchScript, following the practice for TorchScript custom classes. NewTensor defines both __add__ and __matmul__ to support the + and @ operators. However, TorchScript scripts + correctly but fails on @. I am wondering whether this is a bug in TorchScript.

import torch


@torch.jit.script
class NewTensor:
    def __init__(self, value: torch.Tensor):
        self.value = value

    def __matmul__(self, other: "NewTensor") -> torch.Tensor:
        return -torch.matmul(self.value, other.value)

    def __add__(self, other: "NewTensor") -> torch.Tensor:
        return -torch.add(self.value, other.value)

In the code above, when x and y are NewTensor instances, TorchScript resolves x + y to NewTensor.__add__(x, y). But if I use x @ y, TorchScript raises the error shown above instead of calling NewTensor.__matmul__(x, y).
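The error message lists only the aten::matmul overloads, which suggests that the compiler emits @ directly as aten::matmul instead of first looking for a __matmul__ method on the class, as it does for + and __add__. Assuming that is the cause, one possible workaround is to expose the operation under an ordinary method name and call it explicitly in scripted code. This is a minimal sketch: the method name matmul and the MulModel module are hypothetical, and it assumes scripted code can receive a Python instance of a TorchScript class as a forward argument.

```python
import torch


@torch.jit.script
class NewTensor:
    def __init__(self, value: torch.Tensor):
        self.value = value

    # Hypothetical helper: an ordinary, explicitly named method that scripted
    # code can call without going through the @ operator.
    def matmul(self, other: "NewTensor") -> torch.Tensor:
        return -torch.matmul(self.value, other.value)

    # Kept so that x @ y still works in eager (unscripted) Python.
    def __matmul__(self, other: "NewTensor") -> torch.Tensor:
        return self.matmul(other)

    def __add__(self, other: "NewTensor") -> torch.Tensor:
        return -torch.add(self.value, other.value)


class MulModel(torch.nn.Module):
    def forward(self, x: NewTensor, y: NewTensor) -> torch.Tensor:
        # Call the method by name instead of writing x @ y.
        return x.matmul(y)


scripted = torch.jit.script(MulModel())
x = NewTensor(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
y = NewTensor(torch.eye(2))
out = scripted(x, y)
```

Routing __matmul__ through the named method keeps one implementation for both eager and scripted code; only the call site inside the scripted forward has to avoid @.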