I have a `Tensor` subclass that defines custom implementations of some `torch.*` functions. Calling some of these gives me a warning, and I don’t understand why.
Here’s an example:
import torch
def add(input, other, *, alpha=1, out=None):
    """Compute ``torch.add(input, other, alpha=alpha, out=out)`` and return a ``MyTensor``.

    Every tensor argument, including ``out``, is first converted to a plain
    ``torch.Tensor`` alias via ``as_subclass``. This keeps the inner
    ``torch.add`` call from dispatching back into
    ``MyTensor.__torch_function__`` (a ``MyTensor`` ``out`` would otherwise
    re-enter this function forever).

    Non-tensor arguments, such as a Python scalar ``other`` or ``out=None``,
    are passed through unchanged.
    """
    def _plain(x):
        # as_subclass only exists on tensors; leave scalars / None untouched.
        return x.as_subclass(torch.Tensor) if isinstance(x, torch.Tensor) else x

    result = torch.add(_plain(input), _plain(other), alpha=alpha, out=_plain(out))
    return result.as_subclass(MyTensor)
def zeros(*size, dtype=None, device=None, requires_grad=False, out=None):
    """Return a zero-filled tensor wrapped as ``MyTensor``.

    Mirrors ``torch.zeros``: ``size`` may be given as separate ints
    (``zeros(2, 3)``) or as a single sequence (``zeros((2, 3))``).

    Args:
        dtype: Used when it is a real ``torch.dtype``. Any other value falls
            back to ``torch.float32``. This includes ``None`` and the
            ``MyTensor`` class passed as a dispatch tag.
        device: Passed through to ``torch.zeros``.
        requires_grad: Passed through to ``torch.zeros``.
        out: Optional tensor that is zeroed in place. It must already have the
            requested shape, and its own dtype/device are kept. When given,
            a ``MyTensor`` view of it is returned.

    Raises:
        RuntimeError: if ``out`` is given with a shape different from ``size``.
    """
    # torch.zeros((2, 3)) arrives here as size == ((2, 3),); unwrap it.
    if len(size) == 1 and isinstance(size[0], (tuple, list, torch.Size)):
        size = tuple(size[0])

    if out is not None:
        # Zero the caller's buffer through a plain-Tensor alias. It shares its
        # data pointer with `out`, so the write is visible in `out`, and it does
        # not dispatch back into MyTensor.__torch_function__.
        plain_out = out.as_subclass(torch.Tensor)
        if plain_out.shape != torch.Size(size):
            raise RuntimeError(
                f"zeros(): out has shape {tuple(plain_out.shape)}, "
                f"expected {tuple(size)}"
            )
        plain_out.zero_()
        return plain_out.as_subclass(MyTensor)

    if not isinstance(dtype, torch.dtype):
        dtype = torch.float32
    result = torch.zeros(size, dtype=dtype, device=device, requires_grad=requires_grad)
    return result.as_subclass(MyTensor)
class MyTensor(torch.Tensor):
    """``torch.Tensor`` subclass with custom ``torch.add`` and ``torch.zeros``.

    Calls to those two functions are routed to the module-level ``add`` and
    ``zeros`` implementations. Every other torch function is delegated to the
    default ``torch.Tensor.__torch_function__``.
    """

    def __new__(cls, data, requires_grad=False):
        # Wrap the existing tensor `data` as a MyTensor instance.
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``torch.add`` / ``torch.zeros`` to the custom implementations.

        Unhandled functions (e.g. ``Tensor.sum``, ``Tensor.__repr__``) go to
        ``super().__torch_function__``. Falling off the end instead would
        implicitly return ``None`` from every such call.
        """
        if kwargs is None:
            kwargs = {}
        if func is torch.add:
            return add(*args, **kwargs)
        if func is torch.zeros:
            return zeros(*args, **kwargs)
        # Anything we do not override: use PyTorch's default subclass handling.
        return super().__torch_function__(func, types, args, kwargs)
When I call `torch.add`, there is no warning:
# Build two independent random 3x3 MyTensor leaves that require gradients;
# torch.add on MyTensor instances is routed to MyTensor.__torch_function__.
a, b = (MyTensor(torch.randn(3, 3), requires_grad=True) for _ in range(2))
torch.add(a, b)
When I call `torch.zeros`, I get this warning:
# NOTE(review): MyTensor (the class, not an instance) is passed as `dtype`,
# presumably only so that torch.zeros dispatches to MyTensor.__torch_function__.
# It is not a real torch.dtype. This is the call that emits the classmethod
# deprecation warning shown below; the torch.add call on instances does not.
torch.zeros(3, 3, dtype=MyTensor)
>>> UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/python_arg_parser.cpp:346.)
torch.zeros(3, 3, dtype=MyTensor)
The only difference I can see is that here I pass the `MyTensor` class itself as the `dtype` argument, whereas with `torch.add` I pass instances of `MyTensor`.