Find if one tensor is in the autograd history of another

I’d like to check if one tensor is anywhere in the autograd history of another tensor. Is this possible?

Use case: I’m calling torch.autograd.grad on both tensors, and want to avoid double-counting the gradient on the tensor that’s further back in the history. My proposed fix is to (before the forward pass) check if one tensor is in the history of the other, and if it is, transform it through an identity (x = x.view(*x.shape)).
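A minimal sketch of what the identity trick is meant to achieve, under one reading of the setup. The tensors w, t1 and t2 and the function forward are made up for illustration: t2 sits further back in the history, and t1 is built from it. Passing a fresh view of t2 into the forward pass gives autograd a new node that t1 does not depend on. The gradient requested for that view then covers only the direct use, not the path through t1.

```python
import torch

w = torch.ones(3, requires_grad=True)
t2 = w * 3          # tensor further back in the history
t1 = t2 * 2         # t2 is in the history of t1

def forward(a, b):
    # hypothetical forward pass that uses both tensors directly
    return a.sum() + b.sum()

# Without the identity: d(out)/d(t2) also includes the path through t1.
_, g2_plain = torch.autograd.grad(forward(t1, t2), (t1, t2), retain_graph=True)

# With the identity: t2_id is a new node (ViewBackward) that t1 does not depend on.
t2_id = t2.view(*t2.shape)
_, g2_id = torch.autograd.grad(forward(t1, t2_id), (t1, t2_id))

print(g2_plain)  # tensor([3., 3., 3.]): 1 direct + 2 through t1
print(g2_id)     # tensor([1., 1., 1.]): direct use only
```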

I’m not sure I understand this. Since autograd.grad returns the gradient (and doesn’t accumulate it like .backward() does), why would that lead to double counting?


It is a bit tricky because the backward graph is a graph of autograd Nodes (the grad_fn objects), not a graph of Tensors.
But it is possible to check whether there is a link between the two. Note that such a link might still lead to a gradient of 0, and we wouldn’t be able to know that :smiley:

My idea would be to call the two tensors base and target and check whether target is in the history of base.
You can find where target’s gradient will be used by checking target.grad_fn.
Then you can traverse the graph from base’s grad_fn (via grad_fn.next_functions) and see whether you stumble upon target’s grad_fn.
Note that this won’t work if target is a leaf, since a leaf has no grad_fn. In that case, you need to traverse all the nodes reachable from base’s grad_fn and, for each AccumulateGrad node you encounter, check whether node.variable is target.
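A minimal sketch of that traversal, assuming a PyTorch version where grad_fn.next_functions and AccumulateGrad.variable are available (the internal attributes described above). The function name is_in_history and the example tensors are hypothetical. It does a depth-first walk from base.grad_fn. For a non-leaf target it compares nodes against target.grad_fn; for a leaf target it looks for an AccumulateGrad node whose .variable is target.

```python
import torch

def is_in_history(base: torch.Tensor, target: torch.Tensor) -> bool:
    """Return True if target shows up in the backward graph of base."""
    root = base.grad_fn
    if root is None:
        # base is a leaf (or does not require grad): no history to search
        return False
    target_fn = target.grad_fn  # None when target is a leaf
    seen = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            # None marks an input that needs no gradient
            continue
        seen.add(node)
        if target_fn is not None:
            # Non-leaf target: look for the node that produced it
            if node is target_fn:
                return True
        elif hasattr(node, "variable") and node.variable is target:
            # Leaf target: AccumulateGrad nodes hold the leaf in .variable
            return True
        # next_functions is a tuple of (node, input_nr) pairs
        for next_node, _ in node.next_functions:
            stack.append(next_node)
    return False


a = torch.randn(3, requires_grad=True)
b = a * 2
c = b.sin()
d = torch.randn(3, requires_grad=True)
print(is_in_history(c, b))  # True
print(is_in_history(c, a))  # True (leaf, found via AccumulateGrad)
print(is_in_history(c, d))  # False
```

As noted above, a True result only means a path exists; the gradient along it could still be zero.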

Thanks for your response! What you’re saying sounds like it might work; I’ll give it a try.

As for why this leads to double counting: I’m calling autograd.grad with these tensors as the inputs, and the next thing that happens is that (essentially) .backward(calculated_gradient) is called on each of those tensors.

(Actually this is all happening inside an autograd.Function.backward, and the calculated gradients will be passed on to those tensors for the autograd framework to handle in the usual way. This is all part of some custom-backward memory-saving recomputation shenanigans.)
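A minimal sketch of the double counting, with made-up tensors w, t1 and t2 (hypothetical names) where t2 is in the history of t1. The gradient autograd.grad returns for t2 already includes the path through t1. Feeding both gradients back with .backward therefore sends that path's contribution to w twice.

```python
import torch

w = torch.ones(3, requires_grad=True)
t2 = w * 3          # "further back" in the history
t1 = t2 * 2         # t2 is in the history of t1
out = t1.sum()

# Gradients of out w.r.t. both intermediate tensors.
g1, g2 = torch.autograd.grad(out, (t1, t2), retain_graph=True)
# g1 is 1 everywhere, g2 is 2 everywhere: g2 already includes the path through t1

# Feeding each gradient back into its own tensor counts the t1 path twice.
t1.backward(g1, retain_graph=True)
t2.backward(g2)
print(w.grad)  # tensor([12., 12., 12.]) although d(out)/dw is 6
```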

ok!

In that case, yes, the idea above will work for now. Note that these are “internal” APIs and might not be around forever. But they are fairly stable :smiley: