Context
I’m developing a library on top of PyTorch for Jacobian descent, called torchjd. I have a function, torchjd.backward, that I want to be quite similar to torch.autograd.backward in its interface. Basically, the difference is that torch.autograd.backward computes the sum of the gradients of the given tensors with respect to the graph leaves, while torchjd.backward can do more complex aggregations than summing the gradients.
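To make the difference concrete, here is a minimal sketch of the summing behavior of torch.autograd.backward. The tensors a, y1 and y2 are made up for illustration.

```python
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = (a ** 2).sum()  # gradient w.r.t. a: 2 * a = [2., 4.]
y2 = (3 * a).sum()   # gradient w.r.t. a: [3., 3.]

# No inputs are given, so the gradients of y1 and y2 are accumulated
# (summed) into the .grad field of the leaf tensor a.
torch.autograd.backward([y1, y2])
print(a.grad)
```

With torchjd.backward, the idea is that the two gradients would be combined by some other aggregation instead of being summed.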
Problem
I don’t know how to cleanly find the default value of the inputs parameter, i.e. the tensors with respect to which the tensors parameter needs to be differentiated. If the inputs are specified manually, my function works, but, similarly to torch.autograd.backward, I’d like the inputs parameter to default to the leaf tensors that were used to obtain the tensors parameter.
In the case of torch.autograd.backward, these leaf tensors are found by the C++ code. In my case, I would like to keep the library pure Python, so I think that I need to find those tensors by traversing the backward computation graph.
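For reference, this is the kind of structure that would have to be traversed. It is only a minimal sketch, assuming the attributes exposed by current PyTorch versions (grad_fn, next_functions, and the variable attribute of AccumulateGrad nodes). The tensors are made up for illustration.

```python
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0], requires_grad=True)
y = (a * b).sum()

sum_node = y.grad_fn                      # node that produced y
mul_node = sum_node.next_functions[0][0]  # node that produced a * b

# Each entry of next_functions is a (node, input_index) pair.
leaf_nodes = [node for node, _ in mul_node.next_functions]

# Assumption: the nodes attached to leaf tensors expose them as `variable`.
found = [node.variable for node in leaf_nodes]
print([type(node).__name__ for node in leaf_nodes])
```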
Question
Given a list of tensors, how can I find all leaf tensors requiring grad that were used to compute these tensors?
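Currently, our proposed implementation looks like this:

from torch import Tensor


def _get_leaves_of_autograd_graph(roots: list[Tensor]) -> set[Tensor]:
    """
    Gets the leaves of the autograd graph of all tensors in `roots`.

    :param roots: Roots of the autograd graphs.
    """
    # Roots that are themselves leaves have no grad_fn, so skip them.
    nodes_to_traverse = [tensor.grad_fn for tensor in roots if tensor.grad_fn is not None]
    visited = set(nodes_to_traverse)  # avoids revisiting nodes reachable through several paths
    leaves = set()
    while nodes_to_traverse:
        current_node = nodes_to_traverse.pop()
        if hasattr(current_node, "variable"):
            # Nodes attached to leaf tensors expose the tensor as `variable`.
            leaves.add(current_node.variable)
        else:
            for child, _ in current_node.next_functions:
                if child is not None and child not in visited:
                    visited.add(child)
                    nodes_to_traverse.append(child)
    return leaves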
I’m really worried about the hasattr(current_node, "variable") part, since I don’t think there is any guarantee that the leaf tensor of a node will always be stored in the variable field.
Does anyone know how this implementation could be improved? At the moment, I don’t think it is solid enough to be integrated into our library.
Also, shouldn’t PyTorch itself provide an easier way to find such tensors, or a more standard way to traverse the autograd graph?
Thanks for reading this long question. Any help would be really appreciated!