TorchScript ignores user-defined subclasses of Tensor and custom implementations of `__torch_function__`

According to the documentation, PyTorch can be extended by creating a subclass of torch.Tensor and implementing the __torch_function__ method. This works well in standard “eager” execution mode, but it is not supported by TorchScript scripting (torch.jit.script). This is a major drawback because scripting, besides its inherent benefits, is also the preferred way to export models to ONNX.

The issue is twofold:

  1. Methods or attributes of torch.Tensor that are overridden on a subclass of torch.Tensor are silently ignored when used in scripted functions or modules. Similarly, any new methods or attributes that the subclass adds (i.e. that do not exist on torch.Tensor) make torch.jit.script raise a RuntimeError when it compiles a function or module that uses them.
  2. A Tensor subclass with a custom __torch_function__ implementation does not get that __torch_function__ compiled into the scripted graph. Instead, __torch_function__ stays “outside” the graph and is called with the scripted function or method itself as the func argument. This bypasses any per-function overrides written as described in the “Extending PyTorch” documentation.

Issue 1 has previously been documented here, but did not seem to gain much traction. I have not yet found any previous reports of issue 2.

Illustration of issue 1
The script below is similar to the example from here.

  • Eager execution of call_length() prints 123456, as expected, but the scripted version prints 2: the size of the tensor's first dimension, which is what the regular torch.Tensor.__len__ returns, rather than the overridden value.

  • When scripting call_new_method(), our TensorSubClass type is ignored and inferred to be a regular torch.Tensor, which has no new_method() defined on it, so scripting fails with a RuntimeError:

    RuntimeError:
    'Tensor' object has no attribute or method 'new_method'.:
      File "/home/jdh/repos/dreamstream/test_tensor_subclass.py", line 23
    def call_new_method(x: TensorSubClass) -> int:
        return x.new_method()
               ~~~~~~~~~~~~ <--- HERE
    

Issue 1 script

```python
import torch


class TensorSubClass(torch.Tensor):
    def __init__(self, tensor: torch.Tensor):
        super().__init__()
        self.tensor = tensor

    def __len__(self) -> int:
        """Modify the behaviour of len() to return a constant value."""
        return 123456

    def new_method(self) -> int:
        """Add a new method to the subclass that does not exist on `torch.Tensor`."""
        return 123456


def call_length(tensor: TensorSubClass) -> int:
    return len(tensor)


def call_new_method(tensor: TensorSubClass) -> int:
    return tensor.new_method()


torch.manual_seed(0)

call_length_scripted = torch.jit.script(call_length)

tensor = TensorSubClass(torch.rand((2, 3, 4)))

length_from_eager = call_length(tensor)
print("Eager: ", length_from_eager)  # Prints 123456
length_from_jit = call_length_scripted(tensor)
print("Scripted: ", length_from_jit)  # Prints 2

call_new_method_scripted = torch.jit.script(call_new_method)  # Throws RuntimeError
```

Issue 1 root cause
The root cause seems to be the is_tensor function in the torch.jit.annotations source code. It assigns the standard TensorType to any subclass of torch.Tensor, so the subclass's modified or custom methods and attributes are never seen by the compiler.
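One way to check this is to script call_length from the Issue 1 script and inspect the types of its graph inputs. The sketch below assumes that the printed TorchScript IR reflects the types the compiler inferred. The TensorSubClass annotation does not survive: the input is typed as the plain Tensor, and len() compiles to the generic aten::len operator. This would explain why the __len__ override has no effect.

```python
import torch


class TensorSubClass(torch.Tensor):
    def __len__(self) -> int:
        # Eager-only override; the scripted graph never consults it.
        return 123456


def call_length(tensor: TensorSubClass) -> int:
    return len(tensor)


call_length_scripted = torch.jit.script(call_length)

# The annotation said TensorSubClass; collect the type TorchScript actually
# recorded for each graph input, plus the full IR for inspection.
input_types = [str(inp.type()) for inp in call_length_scripted.graph.inputs()]
graph_text = str(call_length_scripted.graph)

print(input_types)
print(graph_text)
```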

Illustration of issue 2
The script below implements a Tensor subclass whose __torch_function__ overrides only torch.nn.functional.linear, adding a large constant (10.0) to its output. It then feeds such a tensor through a model consisting of a linear projection and a sigmoid activation, in both eager and scripted mode.

The output shows that in scripted mode __torch_function__ is called only once, with the scripted Model's forward method as func. As a result, the regular torch.nn.functional.linear runs instead of our override. In other words, __torch_function__ is not recorded as part of the scripted graph, and the custom behaviour is ignored.

```
=== Executing eager model ===
Calling __torch_function__ for <built-in function linear> with name 'linear'
Calling function override for <built-in function linear>
Calling linear override
Calling __torch_function__ for <method 'sum' of 'torch._C._TensorBase' objects> with name 'sum'
Calling __torch_function__ for <built-in method sigmoid of type object at 0x7f192bcfd540> with name 'sigmoid'
Eager:  TensorSubClass([1., 1.], grad_fn=<AliasBackward0>)

=== Executing scripted model ===
Calling __torch_function__ for <torch.ScriptMethod object at 0x7f1828d36de0> with name 'forward'
Scripted:  TensorSubClass([0.5440, 0.4919], grad_fn=<AliasBackward0>)
```

Issue 2 script:

```python
from typing import Any, Union

import torch


def linear_override(input, weight, bias=None):
    print("Calling linear override")
    x = torch.nn.functional.linear(input, weight, bias)
    return x + 10.0


OVERRIDES = {
    torch.nn.functional.linear: linear_override,
}


class TensorSubClass(torch.Tensor):
    def __init__(self, tensor: torch.Tensor):
        super().__init__()
        self.tensor = tensor

    @classmethod
    def to_tensor(cls, input: Any) -> Union[torch.Tensor, Any]:
        return input.tensor if isinstance(input, TensorSubClass) else input

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func is cls.__repr__:
            return super().__torch_function__(func, types, args, kwargs)

        print(f"Calling __torch_function__ for {func} with name '{func.__name__}'")
        if kwargs is None:
            kwargs = {}

        if func not in OVERRIDES:
            return super().__torch_function__(func, types, args, kwargs)

        print(f"Calling function override for {func}")

        # To avoid infinite recursion we first convert all tensors to torch.Tensor.
        args = torch.utils._pytree.tree_map(cls.to_tensor, args)
        kwargs = torch.utils._pytree.tree_map(cls.to_tensor, kwargs)

        # Call override, convert tensors to TensorSubClass and reconstruct pytrees.
        out = OVERRIDES[func](*args, **kwargs)
        return torch.utils._pytree.tree_map(cls, out)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x: TensorSubClass):
        return self.sigmoid(self.linear(x).sum(-1))


torch.manual_seed(0)

model = Model()
scripted_model = torch.jit.script(model)

x = TensorSubClass(torch.rand((2, 4)))

print("\n=== Executing eager model ===")
out_from_eager = model(x)
# Prints `TensorSubClass([1., 1.], grad_fn=<AliasBackward0>)`
print("Eager: ", out_from_eager)

print("\n=== Executing scripted model ===")
out_from_scripted = scripted_model(x)
# Prints `TensorSubClass([0.5440, 0.4919], grad_fn=<AliasBackward0>)`
print("Scripted: ", out_from_scripted, "\n")
```
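A smaller variant of the script above makes the difference easy to assert on. RecordingTensor and forward_fn are hypothetical names for this sketch. The subclass only records the __name__ of each func that reaches __torch_function__ and otherwise defers to the default implementation. Under the behaviour described above, the eager call should record sum and sigmoid. The scripted call should record neither, and at most records a single entry for the scripted callable itself.

```python
import torch


class RecordingTensor(torch.Tensor):
    # Hypothetical helper: class-level log of the __name__ of every func
    # that reaches __torch_function__.
    seen = []

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        cls.seen.append(getattr(func, "__name__", repr(func)))
        # Defer to the default implementation (runs func, re-wraps outputs).
        return super().__torch_function__(func, types, args, kwargs)


def forward_fn(x: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(x.sum(-1))


forward_fn_scripted = torch.jit.script(forward_fn)
x = torch.rand(2, 4).as_subclass(RecordingTensor)

RecordingTensor.seen.clear()
forward_fn(x)
eager_seen = list(RecordingTensor.seen)

RecordingTensor.seen.clear()
forward_fn_scripted(x)
scripted_seen = list(RecordingTensor.seen)

print("eager:   ", eager_seen)
print("scripted:", scripted_seen)
```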

Potential solution for issue 2
A potential solution to issue 2 seems to be extending PyTorch via __torch_dispatch__ instead of __torch_function__. __torch_dispatch__ intercepts aten:: operators rather than torch function calls, but it does not seem to be as well documented. Importantly, most of these operators are still dispatched even after scripting. Operators that scripting rewrites into other aten:: operators could then also be overridden, since the full set of aten:: operators is known a priori. However, this is quite low-level, so some extensions could be hard or impossible to express, or simply better suited to __torch_function__. It also does not solve issue 1, and it is unclear to me whether this would work in practice and whether it would allow ONNX export.
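To probe the claim that aten:: operators are still dispatched after scripting, the sketch below swaps in a TorchDispatchMode (from the private module torch.utils._python_dispatch) instead of a __torch_dispatch__ tensor subclass, because a mode avoids the wrapper-subclass boilerplate. AddmmOverrideMode is a hypothetical aten-level analogue of linear_override: it adds 10.0 to the output of every aten.addmm call. This relies on the assumption that, on CPU, a 2-D F.linear with a bias decomposes into aten.t plus aten.addmm before reaching the mode. The sketch has not been tried with ONNX export and does nothing for issue 1.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode  # private module


class AddmmOverrideMode(TorchDispatchMode):
    """Hypothetical aten-level analogue of linear_override: adds 10.0 to addmm outputs."""

    def __init__(self):
        super().__init__()
        self.ops = []  # string names of every aten overload this mode intercepted

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.ops.append(str(func))
        # The mode is inactive while this handler runs, so these calls do not recurse.
        out = func(*args, **kwargs)
        if func is torch.ops.aten.addmm.default:
            out = out + 10.0
        return out


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.sigmoid(self.linear(x).sum(-1))


torch.manual_seed(0)
scripted_model = torch.jit.script(Model())
x = torch.rand((2, 4))

mode = AddmmOverrideMode()
with torch.no_grad():  # gradients are not needed for this check
    out_plain = scripted_model(x)
    with mode:
        out_overridden = scripted_model(x)

print("intercepted:", mode.ops)
print("plain:      ", out_plain)
print("overridden: ", out_overridden)
```

Note the granularity: the override fires for every addmm, not only for those that originate from F.linear. This is an instance of the “quite low-level” concern raised above.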