Hi, I’m looking to override the gradient computation (ideally in Python) at the level of the computation graph, i.e. not for a specific op like add or sin, but in the backpropagation process through the graph itself, specifically: (a) when the derivative is computed for each node along a path in the computation graph via multiplication, and (b) when the derivative is accumulated at a node from multiple paths via addition.
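To make (a) and (b) concrete, here is the ordinary chain-rule behavior I want to intercept (plain autograd, nothing overridden):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.sin(x) * x  # two paths from x to y: through sin(x) and directly
y.backward()

# (a) per-path products: cos(x) * x (through sin) and sin(x) * 1 (direct)
# (b) accumulation across paths by addition:
assert torch.allclose(x.grad, torch.cos(x) * x + torch.sin(x))
```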
I have been playing around with https://github.com/albanD/subclass_zoo/blob/main/inner_autograd_tensor.py to try to override different parts of the autograd engine.
But while I can easily override individual operations like embedding or add, it appears that torch.ops.aten._backward is never the func passed to __torch_dispatch__ when a user calls .backward() on a tensor in PyTorch.
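A quick way to see what the engine actually dispatches during backward is a dispatch mode. This is a minimal sketch, assuming a PyTorch version (roughly 1.12+) where torch.utils._python_dispatch.TorchDispatchMode is available; the exact ops printed may vary by version:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class PrintOps(TorchDispatchMode):
    # Print every aten op the dispatcher sees, forward and backward.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("dispatched:", func)
        return func(*args, **(kwargs or {}))

x = torch.tensor(2.0, requires_grad=True)
with PrintOps():
    y = torch.sin(x) * x
    y.backward()
```

This prints aten.sin and aten.mul for the forward, then the per-op calls the engine makes during backward (aten.cos, aten.mul, ...), but no single aten._backward.default for the graph traversal.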
I’ve included a short excerpt of the code I’m playing around with below. It contains the implementation of __torch_dispatch__ for a wrapper around torch.Tensor, which can override individual ops but not the backward call. For the full, runnable example, see this Colab notebook.
```python
# Excerpt: `fill_defaults` and `enable_reentrant_dispatch` come from the
# subclass_zoo repo's utilities, and `tree_map` is torch.utils._pytree.tree_map.
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    def unwrap(t):
        if isinstance(t, cls):
            return t.elem
        elif isinstance(t, torch.Tensor) and t.requires_grad:
            # If any other argument at this level does require gradients
            # it will not interact with our inner Tensor and thus this
            # should fail.
            raise RuntimeError("Bad mixup of autograd level")
        else:
            return t

    def wrap(t):
        # Micro-optimization: not necessary to rewrap if the output tensor
        # doesn't require gradients
        if (
            isinstance(t, torch.Tensor)
            and not isinstance(t, cls)
            and t.requires_grad
        ):
            return cls(t)
        else:
            return t

    with enable_reentrant_dispatch():
        # Override gradient behavior
        if func == torch.ops.aten.embedding.default:
            args = fill_defaults(args, 5, [-1, False, False])
            weight, indices, padding_idx, scale_grad_by_freq, _sparse = map(
                unwrap, args
            )
            assert not kwargs
            # Force sparse gradients. We could have also done this by
            # defining a custom autograd function.
            return cls(
                func(weight, indices, padding_idx, scale_grad_by_freq, True)
            )

        if func == torch.ops.aten.add.Tensor:
            # This operation is successfully overridden
            print("OVERRIDING ADD")
            t1, t2 = map(unwrap, args)
            assert not kwargs

            def override_add(t1, t2):
                return func(t1, t2) / 10

            return cls(override_add(t1, t2))

        if func == torch.ops.aten._backward.default:
            # This never gets called!
            print("OVERRIDING BACKWARD")

        return tree_map(
            wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        )
```
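For context on what I've already ruled out: tensor hooks let me observe or rewrite the gradient at a node, but as far as I can tell only after steps (a) and (b) have already run there, e.g.:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.sin(x) * x

# The hook fires with the gradient for x *after* the per-path products
# have been formed and summed; it can rewrite the result, but it cannot
# change how the products or the accumulation themselves are computed.
x.register_hook(lambda g: print("grad at x:", g))
y.backward()
```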
How can I (if it is at all possible) implement __torch_dispatch__ so that the backward call is overridden and controllable by the user?
Thanks for the help!