Can you override torch.ops.aten._backward using __torch_dispatch__?

Hi, I’m looking to override the gradient computation (ideally in Python) at the level of the computation graph. That is, not for a specific op like add or sin, but in the actual backpropagation process through the graph, especially (a) when the derivative is computed for each node along a path in the computation graph via multiplication and (b) when the derivative is accumulated at a node from multiple paths by addition.

I have been playing around with https://github.com/albanD/subclass_zoo/blob/main/inner_autograd_tensor.py to try to override different parts of the autograd engine.

But while I can easily override individual operations like embedding or add, it appears that torch.ops.aten._backward is never passed as func to __torch_dispatch__ when a user calls .backward() on a tensor in PyTorch.

I’ve included a short excerpt of the code I’m playing around with below. This part specifically contains the implementation of __torch_dispatch__ for a wrapper around torch.Tensor which can override individual ops but not the backward call.

For the full, runnable example, see this colab notebook.

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, cls):
                return t.elem
            elif isinstance(t, torch.Tensor) and t.requires_grad:
                # If any other argument at this level does require gradients
                # it will not interact with our inner Tensor and thus this
                # should fail.
                raise RuntimeError("Bad mixup of autograd level")
            else:
                return t

        def wrap(t):
            # Micro-optimization: not necessary to rewrap if the output tensor
            # doesn't require gradients
            if (
                isinstance(t, torch.Tensor)
                and not isinstance(t, cls)
                and t.requires_grad
            ):
                return cls(t)
            else:
                return t

        with enable_reentrant_dispatch():
            # Override gradient behavior
            if func == torch.ops.aten.embedding.default:
                args = fill_defaults(args, 5, [-1, False, False])
                weight, indices, padding_idx, scale_grad_by_freq, _sparse = map(
                    unwrap, args
                )
                assert not kwargs
                # Force sparse gradients.  We could have also done this by
                # defining a custom autograd function.
                return cls(func(weight, indices, padding_idx, scale_grad_by_freq, True))
            if func == torch.ops.aten.add.Tensor:
                # This operation is successfully overridden
                print("OVERRIDING ADD")
                t1, t2 = map(
                    unwrap, args
                )

                assert not kwargs
                def override_add(t1, t2):
                    return func(t1, t2) / 10
                return cls(override_add(t1, t2))
            
            if func == torch.ops.aten._backward.default:
                # this never gets called!
                print("OVERRIDING BACKWARD") 

            # kwargs defaults to None, so fall back to an empty dict before
            # unpacking it with **
            return tree_map(
                wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
            )

How can I (if at all possible) implement __torch_dispatch__ so that the backward call is overridden and controllable by the user?
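For reference, one way to check which ops actually reach the Python dispatcher during .backward() is a logging TorchDispatchMode. This is only a minimal sketch and not part of the notebook: LogOps is a hypothetical helper name, and it assumes a CPU run where the backward pass executes in the calling thread. As far as I can tell, the log shows the individual ATen ops used by each node's backward formula (such as mul) but no aten._backward entry.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class LogOps(TorchDispatchMode):
    """Hypothetical helper: record every ATen op seen by the Python dispatcher."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Record the op name, then run the op unchanged.
        self.seen.append(str(func))
        return func(*args, **(kwargs or {}))


x = torch.tensor([1.0, 2.0], requires_grad=True)
y = (x * x).sum()  # the forward pass runs outside the logging mode

with LogOps() as log:
    y.backward()  # only ops executed during backward are logged

backward_ops = log.seen
print(backward_ops)
```

If this matches what you see, the interception point is the set of per-node backward ops, not a single backward op for the whole graph.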

Thanks for the help!

Hi @kdu,

Have you had a look at using hooks for solving your problem? That can allow for in-place modification of gradients within the Python API. The docs are here: Module — PyTorch 2.0 documentation

Hi @AlphaBetaGamma96, thanks for the reply! I did take a look and I’m not sure if it fits my use case. To my understanding, register_full_backward_pre_hook lets the user define a function that is executed before backprop runs through the graph of the module, and register_full_backward_hook applies a function after backprop has run through the graph of the module. While these hooks are useful for modifying the gradient at the figurative “start” and “end” of the module, they don’t seem to let me modify the gradients of nodes in between.

That is, I’d like to be able to modify how the backprop actually executes through each node within the module, but I think these hooks only give me access to modifying the gradient of the head node(s) of the module (grad_output of register_full_backward_pre_hook) and the gradient of the input nodes (grad_input of register_full_backward_hook). Does that match your understanding of what hooks can/cannot do?
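To make the distinction concrete, here is a minimal sketch of what the two module-level hooks receive. The model, the hook names, and the captured dict are all made up for illustration; the hooks only record shapes and return None, which leaves the gradients unchanged.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 1))
captured = {}


def pre_hook(module, grad_output):
    # Runs before backprop enters the module: sees the gradient w.r.t. the
    # module's outputs. Returning None keeps it unchanged.
    captured["grad_output"] = [tuple(g.shape) for g in grad_output]


def post_hook(module, grad_input, grad_output):
    # Runs after backprop leaves the module: sees the gradient w.r.t. the
    # module's inputs. Nothing inside the module is visible here.
    captured["grad_input"] = [
        None if g is None else tuple(g.shape) for g in grad_input
    ]


model.register_full_backward_pre_hook(pre_hook)
model.register_full_backward_hook(post_hook)

x = torch.randn(2, 3, requires_grad=True)
model(x).sum().backward()
print(captured)
```

Only the gradients at the module boundary show up; the Tanh and the inner Linear layers are not exposed unless hooks are registered on those submodules as well.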

Hi @kdu,

The hooks allow you to access the gradient of the loss with respect to the output (i.e. grad_output), or the inputs to a module (input) for the full_backward and forward_pre hooks respectively.

I assume that when you say nodes here, you mean nodes within the autograd graph itself and not nodes within the network? If so, you can still use hooks, but you’ll need to use the torch.autograd.graph.Node.register_hook function, I think. The docs are here: torch.autograd.graph.Node.register_hook — PyTorch 2.1 documentation
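A rough sketch of what that could look like, assuming a toy graph and a hypothetical scale_hook that halves whatever gradient the node passes back to its inputs. The hook is registered on the grad_fn of an intermediate tensor, so it changes backprop in the middle of the graph and not only at a module boundary.

```python
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)
b = a * 3       # intermediate node (MulBackward0)
c = b.sin()     # intermediate node (SinBackward0)
loss = c.sum()


def scale_hook(grad_inputs, grad_outputs):
    # grad_inputs: gradients this node is about to pass to its inputs.
    # grad_outputs: gradients this node received from downstream.
    # Returning a tuple replaces grad_inputs for the rest of backprop.
    return tuple(None if g is None else g * 0.5 for g in grad_inputs)


# Attach the hook to the sin node only.
handle = c.grad_fn.register_hook(scale_hook)

loss.backward()
print(a.grad)

handle.remove()  # detach the hook once it is no longer needed
```

Without the hook, a.grad would be 3 * cos(3 * a); with it, the gradient flowing out of the sin node is halved before it reaches the mul node.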

Thanks @AlphaBetaGamma96 - I’ve been playing around with it and I think it might be the answer, but I’m not sure if it lets me do a few things, in particular whether I can override how the gradient of a node is accumulated from the gradients of the parent nodes. I opened a separate question about that because it felt sufficiently self-contained: Access gradient of parents of a node during first execution of backprop.