UserWarning for defining `__torch_function__` as a plain method

I have a Tensor subclass that defines custom torch functions. Calling some of them gives me a warning, and I don't understand why.

Here’s an example:

import torch

def add(input, other, *, alpha=1, out=None):
    # Unwrap to plain tensors so the inner torch.add does not re-enter __torch_function__
    input = input.as_subclass(torch.Tensor)
    other = other.as_subclass(torch.Tensor)
    return torch.add(input, other, alpha=alpha, out=out).as_subclass(MyTensor)

def zeros(*size, dtype=None, device=None, requires_grad=False, out=None):
    # dtype arrives as the MyTensor class here, so a real dtype is used instead
    result = torch.zeros(size, dtype=torch.float32, device=device, requires_grad=requires_grad)
    return result.as_subclass(MyTensor)

class MyTensor(torch.Tensor):

    def __new__(cls, data, requires_grad=False):
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func is torch.add:
            return add(*args, **kwargs)
        elif func is torch.zeros:
            return zeros(*args, **kwargs)
        # Fall back to the default torch.Tensor behaviour for every other function
        return super().__torch_function__(func, types, args, kwargs)

When I call torch.add, there is no warning:

a = MyTensor(torch.randn(3, 3), requires_grad=True)
b = MyTensor(torch.randn(3, 3), requires_grad=True)
torch.add(a, b)

When I call torch.zeros, I get this warning:

torch.zeros(3, 3, dtype=MyTensor)
>>> UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/python_arg_parser.cpp:346.)
  torch.zeros(3, 3, dtype=MyTensor)

The only difference I can see is that I pass the MyTensor class as the dtype rather than instances of it as in torch.add.


That warning is raised in pytorch/torch/csrc/utils/python_arg_parser.cpp (pytorch/pytorch on GitHub, commit 79fc0a9141e026e4c237b25e8f69f2b21c55e013).

The reason you see it for

torch.zeros(3, 3, dtype=MyTensor)

is, as you mentioned, that you are passing the class MyTensor (rather than an instance of it) as an argument. Since __torch_function__ is (correctly) a classmethod, MyTensor.__torch_function__.__self__ is MyTensor (the class), which is the very object you passed in. The check treats a __torch_function__ bound to the argument itself as a plain method, so this is a false positive for the warning.
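To make the false positive concrete, here is a small pure-Python approximation of that heuristic. It is not PyTorch's actual C++ code: the helper looks_like_plain_method and the two toy classes are hypothetical, and the sketch assumes the check looks up __torch_function__ on the argument and flags it when the bound method's __self__ is that same argument.

```python
class WithClassmethod:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented

class WithPlainMethod:
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return NotImplemented

def looks_like_plain_method(arg):
    """Hypothetical approximation of the check: flag `arg` when the
    __torch_function__ looked up on it is bound to `arg` itself."""
    tf = getattr(arg, "__torch_function__")
    return getattr(tf, "__self__", None) is arg

# Instance + classmethod: __self__ is the class, not the instance -> not flagged
print(looks_like_plain_method(WithClassmethod()))  # False
# Instance + plain method: __self__ is the instance -> flagged (correctly)
print(looks_like_plain_method(WithPlainMethod()))  # True
# The class itself + classmethod: __self__ is the class == arg -> false positive
print(looks_like_plain_method(WithClassmethod))    # True
# The class itself + plain method: a plain function has no __self__ -> not flagged
print(looks_like_plain_method(WithPlainMethod))    # False
```

The third case is the torch.zeros(3, 3, dtype=MyTensor) situation: a correctly written classmethod looks identical to a plain method once the argument is the class rather than an instance.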

From the __torch_function__ docs, passing instances of the class appears to be the intended (or at least the more common) use case:

PyTorch will invoke your __torch_function__ implementation when an instance of your custom class is passed to a function in the torch namespace
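Following that reading of the docs, one possible workaround (an assumption, not something confirmed by the docs) is to keep the class out of the argument list entirely: expose the constructor as a classmethod on the subclass that builds an ordinary tensor and re-wraps it with as_subclass, the same trick the zeros helper in the question already uses. The name MyTensor.zeros is hypothetical, only the factory is shown (not the question's __torch_function__), and float32 is assumed to be the wanted dtype, as in the question's handler.

```python
import warnings
import torch

class MyTensor(torch.Tensor):
    @classmethod
    def zeros(cls, *size, device=None):
        # Hypothetical factory: no MyTensor object (class or instance) is passed
        # to torch.zeros, so no __torch_function__ override is consulted here.
        result = torch.zeros(*size, dtype=torch.float32, device=device)
        # Re-wrap the plain tensor as the subclass, as the question's zeros helper does.
        return result.as_subclass(cls)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    t = MyTensor.zeros(3, 3)

print(type(t).__name__, tuple(t.shape))  # MyTensor (3, 3)
print(any("plain method" in str(w.message) for w in caught))  # False
```

Call sites would then change from torch.zeros(3, 3, dtype=MyTensor) to MyTensor.zeros(3, 3), so every argument that reaches the torch namespace is either a plain value or an instance.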
