Tensor subclasses don't dispatch to __torch_function__ from inside backward hooks. Is this expected?

I’ve been experimenting with __torch_function__ on both Tensor subclasses and Tensor-like objects. I’ve found that __torch_function__ is not dispatched to when calling functions from the torch namespace inside backward hooks. For example:

import torch

class MyTensor(torch.Tensor):

    def __new__(cls, data):
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        print(f"[MyTensor __torch_function__] func={func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)

x = MyTensor(torch.randn(3, requires_grad=True))
y = MyTensor(torch.randn(3, requires_grad=True))

def post_accum_hook(param):
    print("Start of hook")
    param + MyTensor(torch.ones_like(param))
    print("End of hook")

x.register_post_accumulate_grad_hook(post_accum_hook)

z = torch.add(x, y)

torch.sum(z).backward()
>>> [MyTensor __torch_function__] func=register_post_accumulate_grad_hook
>>> [MyTensor __torch_function__] func=add
>>> [MyTensor __torch_function__] func=sum
>>> [MyTensor __torch_function__] func=backward
>>> Start of hook
>>> End of hook

If the __torch_function__ dispatch mechanism were active inside the hook, we would expect:

>>> Start of hook
>>> [MyTensor __torch_function__] func=add
>>> End of hook

By contrast, the __torch_function__ mechanism works fine inside the hook with Tensor-like objects:

class MyTensorLike:

    def __init__(self, data):
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        print(f"[MyTensorLike __torch_function__] func={func.__name__}")
        result = func(*(arg.data if isinstance(arg, MyTensorLike) else arg for arg in args), **kwargs)
        return cls(result)

a = MyTensorLike(torch.randn(3, requires_grad=True))
b = MyTensorLike(torch.randn(3, requires_grad=True))

def post_accum_grad_hook(param):
    print("Start of hook")
    param + MyTensorLike(torch.ones_like(param.data))
    print("End of hook")

a.data.register_post_accumulate_grad_hook(post_accum_grad_hook)

c = torch.add(a, b)
torch.sum(c).data.backward()
>>> [MyTensorLike __torch_function__] func=add
>>> [MyTensorLike __torch_function__] func=sum
>>> Start of hook
>>> [MyTensorLike __torch_function__] func=add
>>> End of hook

I’ve found that you can re-enable the __torch_function__ dispatch mechanism using the torch._C._EnableTorchFunction context manager:

def post_accum_hook(param):
    print("Start of hook")
    with torch._C._EnableTorchFunction():
        param + MyTensor(torch.ones_like(param))
    print("End of hook")
>>> Start of hook
>>> [MyTensor __torch_function__] func=add
>>> End of hook
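
The workaround also generalizes to a small wrapper so the context manager doesn’t have to be repeated in every hook. This is just a sketch on my end, not an existing PyTorch API, and with_torch_function is a name I made up:

import functools

def with_torch_function(hook):
    # Re-enable __torch_function__ dispatch before running the hook body.
    @functools.wraps(hook)
    def wrapper(param):
        with torch._C._EnableTorchFunction():
            return hook(param)
    return wrapper

# Wrap the original post_accum_hook (the version without the context manager).
x.register_post_accumulate_grad_hook(with_torch_function(post_accum_hook))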

However, I’m wondering whether this is the expected behavior, and whether it would make more sense for __torch_function__ dispatch to be re-enabled automatically before the hook is called, to match the behavior for Tensor-like objects. I couldn’t find any other threads or docs mentioning this. Let me know any thoughts/ideas you have. Thanks.


I believe this happens because, for MyTensor, calling .backward() goes through super().__torch_function__(func, types, args, kwargs), which runs under with _C.DisableTorchFunctionSubclass():. The backward hook executes within that scope, so __torch_function__ dispatch is disabled while it runs.

For MyTensorLike, on the other hand, there is no such super() call, so the backward hook runs with __torch_function__ enabled.
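
You can see the same suppression outside of autograd with a minimal check (assuming the MyTensor class from your first snippet; p and q are just fresh tensors for illustration):

p = MyTensor(torch.randn(3))
q = MyTensor(torch.randn(3))

p + q  # prints "[MyTensor __torch_function__] func=add"

with torch._C.DisableTorchFunctionSubclass():
    p + q  # nothing printed; subclass dispatch is suppressed inside this scope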

Thanks for providing the code snippet. It makes sense that the function is called under the _C.DisableTorchFunctionSubclass context, since otherwise calling the subclass’s __torch_function__ would recurse infinitely. My question is more about whether a change should be made to re-enable dispatch when the autograd engine calls backward hooks, rather than making the user do it themselves.

I see, do feel free to file an issue.