I’ve been experimenting with __torch_function__ on both Tensor subclasses and Tensor-like objects. I’ve found that __torch_function__ isn’t dispatched to when calling functions in the torch namespace from inside autograd hooks (a post-accumulate-grad hook in the examples below). For example:
import torch

class MyTensor(torch.Tensor):
    def __new__(cls, data):
        # Wrap an existing tensor in the subclass without copying storage
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"[MyTensor __torch_function__] func={func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)

x = MyTensor(torch.randn(3, requires_grad=True))
y = MyTensor(torch.randn(3, requires_grad=True))

def post_accum_hook(param):
    print("Start of hook")
    # Result is discarded; this add is only here to trigger dispatch
    param + MyTensor(torch.ones_like(param))
    print("End of hook")

x.register_post_accumulate_grad_hook(post_accum_hook)

z = torch.add(x, y)
torch.sum(z).backward()

Running this prints:
>>> [MyTensor __torch_function__] func=register_post_accumulate_grad_hook
>>> [MyTensor __torch_function__] func=add
>>> [MyTensor __torch_function__] func=sum
>>> [MyTensor __torch_function__] func=backward
>>> Start of hook
>>> End of hook
If the __torch_function__ dispatch mechanism were active inside the hook, we would expect the hook portion of the output to instead be:
>>> Start of hook
>>> [MyTensor __torch_function__] func=add
>>> End of hook
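As a side check, the dispatch state can be queried from inside the hook rather than inferred from missing prints. This is a rough sketch; it assumes the private torch._C._is_torch_function_enabled query (internal and undocumented, so it may change or behave differently across releases) reflects the state the hook runs under:

w = MyTensor(torch.randn(3, requires_grad=True))

def probe_hook(param):
    # Private, undocumented query for the global __torch_function__
    # dispatch state; may change between PyTorch versions.
    print("torch function enabled inside hook:",
          torch._C._is_torch_function_enabled())

w.register_post_accumulate_grad_hook(probe_hook)
w.sum().backward()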
By comparison, the __torch_function__ mechanism works fine inside the hook with Tensor-like objects (objects that define __torch_function__ but do not subclass Tensor):
class MyTensorLike:
    def __init__(self, data):
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"[MyTensorLike __torch_function__] func={func.__name__}")
        # Unwrap MyTensorLike arguments, run the real op, and re-wrap the result
        result = func(*(arg.data if isinstance(arg, MyTensorLike) else arg
                        for arg in args), **kwargs)
        return cls(result)

a = MyTensorLike(torch.randn(3, requires_grad=True))
b = MyTensorLike(torch.randn(3, requires_grad=True))

def post_accum_grad_hook(param):
    print("Start of hook")
    param + MyTensorLike(torch.ones_like(param.data))
    print("End of hook")

# The hook must be registered on the underlying tensor, since
# MyTensorLike is not itself a Tensor
a.data.register_post_accumulate_grad_hook(post_accum_grad_hook)

c = torch.add(a, b)
torch.sum(c).data.backward()

Output:
>>> [MyTensorLike __torch_function__] func=add
>>> [MyTensorLike __torch_function__] func=sum
>>> Start of hook
>>> [MyTensorLike __torch_function__] func=add
>>> End of hook
I’ve found that you can re-enable the __torch_function__ dispatch mechanism inside the hook using the private torch._C._EnableTorchFunction context manager:
def post_accum_hook(param):
    print("Start of hook")
    # Re-enable __torch_function__ dispatch for the body of the hook
    with torch._C._EnableTorchFunction():
        param + MyTensor(torch.ones_like(param))
    print("End of hook")

Now the hook portion of the output is:
>>> Start of hook
>>> [MyTensor __torch_function__] func=add
>>> End of hook
However, I’m wondering whether this is the intended behavior, and whether it would make more sense for __torch_function__ dispatch to be re-enabled automatically before the hook is called, so that Tensor subclasses match the behavior of Tensor-like objects. I couldn’t find any other threads or docs mentioning this anywhere.
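In the meantime, to avoid editing every hook body by hand, the context manager can be factored into a small decorator. This is just a sketch, and with_torch_function is a hypothetical helper name of mine (not a PyTorch API); it relies on the same private _EnableTorchFunction context manager, with the same stability caveats:

import functools

def with_torch_function(hook):
    # Re-enable __torch_function__ dispatch for the duration of the hook.
    # Relies on the private torch._C._EnableTorchFunction context manager,
    # which may change between PyTorch releases.
    @functools.wraps(hook)
    def wrapper(*args, **kwargs):
        with torch._C._EnableTorchFunction():
            return hook(*args, **kwargs)
    return wrapper

x.register_post_accumulate_grad_hook(with_torch_function(post_accum_hook))

Let me know any thoughts/ideas you have. Thanks.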