Hi, I’m looking to override the gradient computation (ideally in Python) at the level of the computation graph: not for a specific op like add or sin, but in the backpropagation process through the graph itself, specifically (a) where the derivative along each edge of the graph is computed by multiplying the upstream gradient by the local derivative, and (b) where the derivative at a node is accumulated by adding the contributions from multiple paths.
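To make (a) and (b) concrete, here is a toy scalar example in plain Python (no PyTorch, just the chain rule by hand) showing exactly the two operations I’d like to hook:

    import math

    # z = x*y + sin(x): two paths from x to z
    x, y = 0.5, 2.0
    a = x * y
    b = math.sin(x)
    z = a + b

    # Reverse pass by hand:
    dz_dz = 1.0
    dz_da = dz_dz * 1.0                 # (a) multiply upstream grad by local derivative
    dz_db = dz_dz * 1.0
    dz_dy = dz_da * x
    dz_dx_via_a = dz_da * y             # (a) again, along the x*y path
    dz_dx_via_b = dz_db * math.cos(x)   # (a) again, along the sin(x) path
    dz_dx = dz_dx_via_a + dz_dx_via_b   # (b) accumulate the two paths by addition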
I have been playing around with https://github.com/albanD/subclass_zoo/blob/main/inner_autograd_tensor.py to try to override different parts of the autograd engine.
But while I can easily override individual operations like embedding or add, it appears that torch.ops.aten._backward is never passed as func to __torch_dispatch__ when a user calls .backward() on a tensor in PyTorch.
I’ve included a short excerpt of the code I’m playing around with below. This part specifically contains the implementation of __torch_dispatch__
for a wrapper around torch.Tensor which can override individual ops but not the backward call.
For the full, runnable example, see this colab notebook.
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    def unwrap(t):
        if isinstance(t, cls):
            return t.elem
        elif isinstance(t, torch.Tensor) and t.requires_grad:
            # If any other argument at this level does require gradients
            # it will not interact with our inner Tensor and thus this
            # should fail.
            raise RuntimeError("Bad mixup of autograd level")
        else:
            return t

    def wrap(t):
        # Micro-optimization: not necessary to rewrap if the output tensor
        # doesn't require gradients
        if (
            isinstance(t, torch.Tensor)
            and not isinstance(t, cls)
            and t.requires_grad
        ):
            return cls(t)
        else:
            return t

    with enable_reentrant_dispatch():
        # Override gradient behavior
        if func == torch.ops.aten.embedding.default:
            args = fill_defaults(args, 5, [-1, False, False])
            weight, indices, padding_idx, scale_grad_by_freq, _sparse = map(
                unwrap, args
            )
            assert not kwargs
            # Force sparse gradients. We could have also done this by
            # defining a custom autograd function.
            return cls(func(weight, indices, padding_idx, scale_grad_by_freq, True))

        if func == torch.ops.aten.add.Tensor:
            # This operation is successfully overridden
            print("OVERRIDING ADD")
            t1, t2 = map(unwrap, args)
            assert not kwargs

            def override_add(t1, t2):
                return func(t1, t2) / 10

            return cls(override_add(t1, t2))

        if func == torch.ops.aten._backward.default:
            # This never gets called!
            print("OVERRIDING BACKWARD")
            return tree_map(
                wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
            )

        # Generic fallback for every other op: unwrap args, run the real op,
        # and rewrap the outputs.
        return tree_map(
            wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        )
How can I (if at all possible) implement __torch_dispatch__
so that the backward call is overridden and controllable by the user?
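For context on what I mean by “controllable”: as far as I understand, Tensor.register_hook (and the grad_fn / Node hooks) only let me post-process a gradient after the engine has already done the multiplication in (a) and the accumulation in (b), e.g.:

    import torch

    x = torch.tensor([1.0, 2.0], requires_grad=True)
    # Rescale the gradient arriving at x; this runs after the engine has
    # already multiplied local derivatives and summed over paths.
    x.register_hook(lambda grad: grad / 10)

    y = (x * x).sum()
    y.backward()
    print(x.grad)  # tensor([0.2000, 0.4000]) instead of tensor([2., 4.])

That rescales the finished gradient at a node, but it doesn’t let me replace the multiplication or the addition steps themselves, which is what I’m after.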
Thanks for the help!