Hi, I am trying to integrate torch’s autograd with a custom library, and I wonder if it’s possible to do the following:
import torch

a = torch.randn(5, 3, requires_grad=True)
l = a.sum()
l.backward()
When I call `l.backward()`, my understanding is that the backward pass goes all the way from `l` back to `a`. In our library, I would like this `l.backward()` call to trigger some behavior of `a` itself. I am trying to implement a class that inherits from `torch.Tensor` and use it in place of `a`, and I want it to do something when the backward pass reaches `a`. Is this possible?
Something like this:
class CustomTensor(torch.Tensor):
    def reached_this_tensor(self):
        print('reached')

a = CustomTensor([1,2,3])
l = a.sum()
l.backward()
And it would print `reached`.
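For comparison, here is a minimal sketch of the closest thing I know of without subclassing: a tensor hook registered with `Tensor.register_hook`, which autograd calls when the gradient for that tensor is computed. The `events` list is only there for illustration, and `reached_this_tensor` is written as a plain function taking the gradient, not as the method from my example above. I am not sure this is the right way to get the same behavior from inside a `torch.Tensor` subclass.

```python
import torch

events = []  # illustrative only: records each time the backward pass reaches `a`

def reached_this_tensor(grad):
    # Autograd calls this with the gradient computed for `a`.
    events.append(tuple(grad.shape))
    print('reached')
    # Returning None leaves the gradient unchanged.

a = torch.randn(5, 3, requires_grad=True)
handle = a.register_hook(reached_this_tensor)  # hook fires during backward

l = a.sum()
l.backward()

handle.remove()  # detach the hook once it is no longer needed
```

The hook receives the gradient, not the tensor itself, so any per-tensor state would have to be captured in a closure or stored on the object that registers the hook.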