According to the documentation, PyTorch can be extended by creating a subclass of `torch.Tensor` and implementing the `__torch_function__` method. This works perfectly fine in standard “eager” execution mode; however, it is not supported by scripting. This is a major drawback since scripting, besides its inherent benefits, is also the preferred way to export models to ONNX.
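(For context, a minimal sketch of one such benefit: unlike tracing, scripting compiles data-dependent control flow instead of freezing the branch taken for the example input, and this carries over to ONNX export. The toy `Gate` module below is purely illustrative.)

```python
import torch


class Gate(torch.nn.Module):
    """Toy module with a data-dependent branch (purely illustrative)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:  # tracing would bake in one branch; scripting keeps both
            return x * 2
        return x


scripted = torch.jit.script(Gate())
print(scripted(torch.full((3,), -1.0)))  # tensor([-1., -1., -1.])
print(scripted(torch.full((3,), 1.0)))   # tensor([2., 2., 2.])
```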
The issue is twofold:

- Methods or attributes of the original `torch.Tensor` that are modified on a subclass of `torch.Tensor` are silently ignored when used in scripted methods/modules. Similarly, any new methods or attributes of the `torch.Tensor` subclass (not originally on `torch.Tensor`) result in a `RuntimeError` when calling `torch.jit.script` on a method or module that calls them.
- A Tensor subclass that has a custom `__torch_function__` implementation doesn’t get its `__torch_function__` scripted into the static graph. Instead, the `__torch_function__` is “outside” the graph and is called with the scripted function(s) as the `func` argument. This bypasses any custom overrides written by following the “Extending PyTorch” documentation.
Issue 1 has previously been documented here, but did not seem to gain much traction. I have not yet found any previous reports of issue 2.
Illustration of issue 1
The script below is similar to the example from here.
- The eager execution of `call_length()` prints 123456 as expected, but the scripted version prints 2, matching the actual tensor length but not the override.
- When scripting `call_new_method()`, our `TensorSubClass` type is ignored and inferred to be a regular `torch.Tensor`, which does not have a `new_method()` defined on it, so scripting fails with a `RuntimeError`:

```
RuntimeError: 'Tensor' object has no attribute or method 'new_method'.:
  File "/home/jdh/repos/dreamstream/test_tensor_subclass.py", line 23
    def call_new_method(x: TensorSubClass) -> int:
        return x.new_method()
               ~~~~~~~~~~~~ <--- HERE
```
Issue 1 script
```python
import torch


class TensorSubClass(torch.Tensor):
    def __init__(self, tensor: torch.Tensor):
        super().__init__()
        self.tensor = tensor

    def __len__(self) -> int:
        """Modify the behaviour of len() to return a constant value."""
        return 123456

    def new_method(self) -> int:
        """Add a new method to the subclass that does not exist on `torch.Tensor`."""
        return 123456


def call_length(tensor: TensorSubClass) -> int:
    return len(tensor)


def call_new_method(tensor: TensorSubClass) -> int:
    return tensor.new_method()


torch.manual_seed(0)

call_length_scripted = torch.jit.script(call_length)

tensor = TensorSubClass(torch.rand((2, 3, 4)))

length_from_eager = call_length(tensor)
print("Eager: ", length_from_eager)  # Prints 123456

length_from_jit = call_length_scripted(tensor)
print("Scripted: ", length_from_jit)  # Prints 2

call_new_method_scripted = torch.jit.script(call_new_method)  # Throws RuntimeError
```
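A quick way to see what happened to `call_length()` is to inspect the compiled graph (the exact textual output is version-dependent):

```python
# The graph shows that len(tensor) was compiled against the generic Tensor
# type (e.g. as an aten::len call on a plain Tensor); TensorSubClass.__len__
# never makes it into the graph.
print(call_length_scripted.graph)
```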
Issue 1 root cause
The root cause seems to be related to the `is_tensor` function in the `torch.jit.annotations` source code, which assigns the standard `TensorType` to subclasses of `torch.Tensor`, ignoring any modified/custom methods or attributes.
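This can be checked directly (a small sketch; `is_tensor` is a private helper, so its exact location and behaviour may change between PyTorch versions):

```python
import torch
from torch.jit.annotations import is_tensor  # private API


class TensorSubClass(torch.Tensor):
    def new_method(self) -> int:
        return 123456


# Any subclass of torch.Tensor passes this check, so the compiler assigns the
# generic TensorType to TensorSubClass annotations and never sees new_method().
print(is_tensor(TensorSubClass))  # True
print(is_tensor(torch.Tensor))    # True
```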
Illustration of issue 2
The script below implements a Tensor subclass with a `__torch_function__` that overrides only `torch.nn.functional.linear` by adding a large bias to the output. It then feeds such a tensor through a model consisting of a linear projection and a sigmoid activation in eager and scripted mode.

The output shows how scripted mode calls `__torch_function__` with the outer `forward` method of the scripted `Model` and therefore executes the regular `torch.nn.functional.linear` instead of our override. That is, `__torch_function__` is not recorded as part of the scripting, and the custom behaviour gets ignored.
```
=== Executing eager model ===
Calling __torch_function__ for <built-in function linear> with name 'linear'
Calling function override for <built-in function linear>
Calling linear override
Calling __torch_function__ for <method 'sum' of 'torch._C._TensorBase' objects> with name 'sum'
Calling __torch_function__ for <built-in method sigmoid of type object at 0x7f192bcfd540> with name 'sigmoid'
Eager: TensorSubClass([1., 1.], grad_fn=<AliasBackward0>)

=== Executing scripted model ===
Calling __torch_function__ for <torch.ScriptMethod object at 0x7f1828d36de0> with name 'forward'
Scripted: TensorSubClass([0.5440, 0.4919], grad_fn=<AliasBackward0>)
```
Issue 2 script

```python
from typing import Any, Union

import torch
import torch.utils._pytree  # used for tree_map below


def linear_override(input, weight, bias=None):
    print("Calling linear override")
    x = torch.nn.functional.linear(input, weight, bias)
    return x + 10.0


OVERRIDES = {
    torch.nn.functional.linear: linear_override,
}


class TensorSubClass(torch.Tensor):
    def __init__(self, tensor: torch.Tensor):
        super().__init__()
        self.tensor = tensor

    @classmethod
    def to_tensor(cls, input: Any) -> Union[torch.Tensor, Any]:
        return input.tensor if isinstance(input, TensorSubClass) else input

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func is cls.__repr__:
            return super().__torch_function__(func, types, args, kwargs)

        print(f"Calling __torch_function__ for {func} with name '{func.__name__}'")
        if kwargs is None:
            kwargs = {}

        if func not in OVERRIDES:
            return super().__torch_function__(func, types, args, kwargs)

        print(f"Calling function override for {func}")
        # To avoid infinite recursion we first convert all tensors to torch.Tensor.
        args = torch.utils._pytree.tree_map(cls.to_tensor, args)
        kwargs = torch.utils._pytree.tree_map(cls.to_tensor, kwargs)
        # Call override, convert tensors to TensorSubClass and reconstruct pytrees.
        out = OVERRIDES[func](*args, **kwargs)
        return torch.utils._pytree.tree_map(cls, out)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x: TensorSubClass):
        return self.sigmoid(self.linear(x).sum(-1))


torch.manual_seed(0)

model = Model()
scripted_model = torch.jit.script(model)

x = TensorSubClass(torch.rand((2, 4)))

print("\n=== Executing eager model ===")
out_from_eager = model(x)
# Prints `TensorSubClass([1., 1.], grad_fn=<AliasBackward0>)`
print("Eager: ", out_from_eager)

print("\n=== Executing scripted model ===")
out_from_scripted = scripted_model(x)
# Prints `TensorSubClass([0.5440, 0.4919], grad_fn=<AliasBackward0>)`
print("Scripted: ", out_from_scripted, "\n")
```
Potential solution for issue 2
A potential solution to issue 2 seems to be extending PyTorch via `__torch_dispatch__` instead of `__torch_function__`. It intercepts `aten::` operators instead of torch function calls, but doesn’t seem quite as well documented. Importantly, most of these operators are still dispatched even after scripting (those that are changed to other `aten::` operators could then also be overridden, since we know the full set of `aten::` operators a priori). However, this is quite low-level, so some extensions could be hard or impossible to make, or just better suited for `__torch_function__`. Also, it doesn’t solve issue 1, and it’s unclear to me whether this would work in practice and whether it would allow ONNX export.
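As a rough sketch of what that could look like (modelled on the wrapper-subclass pattern used by PyTorch’s internal `LoggingTensor` test utility; the assumption that a biased 2D linear reaches the dispatcher as `aten.addmm.default` may vary across versions, and I have not verified behaviour under scripting or ONNX export):

```python
from typing import Any, Union

import torch
from torch.utils._pytree import tree_map


class DispatchTensorSubClass(torch.Tensor):
    """Wrapper subclass intercepting aten:: operators via __torch_dispatch__."""

    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        # Standard wrapper-subclass construction.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=elem.device, requires_grad=elem.requires_grad,
        )
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t: Any) -> Union[torch.Tensor, Any]:
            return t.elem if isinstance(t, DispatchTensorSubClass) else t

        def wrap(t: Any) -> Any:
            return cls(t) if isinstance(t, torch.Tensor) else t

        # Unwrap to plain tensors to avoid re-entering this dispatch.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        # Assumption: with a bias, a 2D linear lowers to aten.addmm.default
        # by the time it reaches __torch_dispatch__.
        if func is torch.ops.aten.addmm.default:
            print(f"Intercepted {func}")
            out = out + 10.0
        return tree_map(wrap, out)


linear = torch.nn.Linear(4, 2)
x = DispatchTensorSubClass(torch.rand(2, 4))
y = linear(x)  # the underlying addmm is intercepted at the aten level
```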