Is there any way to detect when a backward pass reaches a certain node?

Hi, I am trying to integrate torch’s autograd with a custom library, and I wonder if it’s possible to do the following:

a = torch.randn(5, 3, requires_grad=True)

l = a.sum()
l.backward()

When I call l.backward(), my understanding is that the backward pass runs from l all the way back to a. In our library, I would like this l.backward() call to trigger some behavior on a itself. I am trying to implement a class that subclasses torch.Tensor, use it to replace a, and have it do something when the backward pass reaches a. Is that possible?

Something like this:

class CustomTensor(torch.Tensor):

    def reached_this_tensor(self):
        print('reached')

a = CustomTensor([1,2,3])
l = a.sum()
l.backward()

and have it print reached.

Tensor.register_hook should do this: the hook fires whenever autograd computes a gradient with respect to that tensor, i.e. exactly when the backward pass reaches it:

a = torch.randn(5, 3, requires_grad=True)
a.register_hook(lambda grad: print("reached a with grad {}".format(grad)))

l = a.sum()
l.backward()
# reached a with grad tensor([[1., 1., 1.],
#         [1., 1., 1.],
#         [1., 1., 1.],
#         [1., 1., 1.],
#         [1., 1., 1.]])
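To combine this with your subclass idea, one sketch is to register the hook when the tensor is constructed. This relies on the private torch.Tensor._make_subclass helper, and the method names (reached_this_tensor, _on_grad) are just the ones from your example, not any official API:

```python
import torch

class CustomTensor(torch.Tensor):
    # Hypothetical subclass: one way to attach per-tensor backward behavior.
    # torch.Tensor._make_subclass is a private helper, so this pattern may
    # change between PyTorch versions.
    @staticmethod
    def __new__(cls, data, requires_grad=True):
        base = torch.as_tensor(data, dtype=torch.float32)
        t = torch.Tensor._make_subclass(cls, base, requires_grad)
        # The hook fires whenever autograd computes a gradient w.r.t. t,
        # i.e. when the backward pass "reaches" this tensor.
        t.register_hook(t._on_grad)
        return t

    def _on_grad(self, grad):
        self.reached_this_tensor()
        return grad  # pass the gradient through unchanged

    def reached_this_tensor(self):
        print('reached')

a = CustomTensor([1., 2., 3.])
l = a.sum()
l.backward()  # prints: reached
```

Note that register_hook only works on tensors with requires_grad=True, which is why the constructor defaults it on.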

Oh wow!! Thanks so much!