Autograd graph traversal

Context

I’m developing a library on top of PyTorch for Jacobian descent, called torchjd. It has a function torchjd.backward that I want to keep quite similar to torch.autograd.backward in its interface. The main difference is that torch.autograd.backward computes the sum of the gradients of the given tensors with respect to the graph leaves, while torchjd.backward can combine those gradients with more complex aggregations than a plain sum.
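As a rough illustration of the intended interface (just a sketch: the aggregator argument below is an assumption for illustration, not necessarily the actual torchjd signature):

import torch

a = torch.randn(3, requires_grad=True)
loss1 = (a ** 2).sum()
loss2 = a.abs().sum()

# torch.autograd.backward accumulates the *sum* of the two gradients into a.grad.
torch.autograd.backward([loss1, loss2])

# torchjd.backward is meant to look the same, but to combine the two gradients
# with a configurable aggregator instead of a plain sum, e.g. (illustrative only):
# torchjd.backward([loss1, loss2], aggregator=some_aggregator)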

Problem

I don’t know how to cleanly find the default inputs with respect to which the tensors need to be differentiated. If those are specified manually, my function works, but similarly to torch.autograd.backward, I’d like the inputs parameter to default to the leaf tensors that were used to compute the tensors parameter.
In torch.autograd.backward, these leaf tensors are found in the C++ code. In my case, I’d like to keep the library purely Python, so I think I need to find them by traversing the backward computation graph.
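For reference, the backward graph is reachable from Python through grad_fn and next_functions, with each leaf appearing as an AccumulateGrad node. A tiny example:

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
y = (a * b).sum()

print(y.grad_fn)                 # <SumBackward0 object at ...>
print(y.grad_fn.next_functions)  # ((<MulBackward0 object at ...>, 0),)
mul_node = y.grad_fn.next_functions[0][0]
print(mul_node.next_functions)   # ((<AccumulateGrad ...>, 0), (<AccumulateGrad ...>, 0))
acc_node = mul_node.next_functions[0][0]
print(acc_node.variable is a)    # True: AccumulateGrad nodes hold their leaf tensor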

Question

Given a list of tensors, how can I find all leaf tensors requiring grad that were used to compute these tensors?

Currently, our proposed implementation looks like this:

from torch import Tensor


def _get_leaves_of_autograd_graph(roots: list[Tensor]) -> set[Tensor]:
    """
    Gets the leaves of the autograd graph of all tensors in `roots`.

    :param roots: Roots of the autograd graphs.
    """
    nodes_to_traverse = [tensor.grad_fn for tensor in roots]
    leaves = set()

    while nodes_to_traverse:
        current_node = nodes_to_traverse.pop()

        # AccumulateGrad nodes store a reference to their leaf tensor in `variable`.
        if hasattr(current_node, "variable"):
            leaves.add(current_node.variable)
        else:
            nodes_to_traverse += [
                child[0] for child in current_node.next_functions if child[0] is not None
            ]

    return leaves
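For example, on a toy graph (tensor names here are just for illustration):

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
loss1 = (a * b).sum()
loss2 = (a ** 2).sum()

leaves = _get_leaves_of_autograd_graph([loss1, loss2])
# leaves contains exactly a and b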

I’m really worried about the hasattr(current_node, "variable") part, since I don’t think there is any guarantee that the leaf tensor of a node will always be stored in a field called variable.
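The only slightly more explicit alternative I can think of is to check the node’s class name instead of the attribute (just a sketch, and it seems equally tied to internals):

def _is_accumulate_grad_node(node) -> bool:
    # AccumulateGrad is the node type that wraps a leaf tensor.
    return type(node).__name__ == "AccumulateGrad"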

Does anyone know how this implementation could be improved? At the moment, I don’t think it is solid enough to be integrated into our library.

Also, shouldn’t PyTorch itself provide an easier way to find such tensors, or a more standard way to traverse the autograd graph?

Thanks for reading this long question. Any help would be really appreciated!

nodes_to_traverse = [tensor.grad_fn for tensor in roots]

One edge case to note: if a tensor in roots is itself a leaf, its grad_fn will be None.
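A simple way to handle that in the snippet from the question (just a sketch) is to seed the traversal only with non-None grad_fns and to treat leaf roots that require grad directly as leaves:

nodes_to_traverse = []
leaves = set()
for tensor in roots:
    if tensor.grad_fn is not None:
        nodes_to_traverse.append(tensor.grad_fn)
    elif tensor.requires_grad:
        # A root that is itself a leaf has no grad_fn to traverse.
        leaves.add(tensor)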

hasattr(current_node, "variable")

I don’t expect this to change; it should be fine to use.

Also, shouldn’t PyTorch itself provide an easier way to find such tensors, or a more standard way to traverse the autograd graph?

Sounds reasonable.

We already have this generator that is used in a couple of places in the code, and we can probably expose it:

from collections import deque
from typing import Deque, Iterator, List, Set

from torch.autograd.graph import Node


def iter_graph(roots: List[Node]) -> Iterator[Node]:
    # Breadth-first traversal over the autograd graph, yielding each node once.
    if not roots:
        return
    seen: Set[Node] = set()
    q: Deque[Node] = deque()
    for node in roots:
        if node is not None and node not in seen:
            seen.add(node)
            q.append(node)

    while q:
        node = q.popleft()
        for fn, _ in node.next_functions:
            if fn in seen or fn is None:
                continue
            seen.add(fn)
            q.append(fn)

        yield node

This seems reasonable to have; we’d accept a PR implementing it.
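For instance, once such a generator is exposed, the leaf discovery from the question could be layered on top of it (a sketch; find_leaf_tensors is just an illustrative name):

from typing import List, Set

from torch import Tensor


def find_leaf_tensors(tensors: List[Tensor]) -> Set[Tensor]:
    roots = [t.grad_fn for t in tensors if t.grad_fn is not None]
    leaves = set()
    for node in iter_graph(roots):
        # AccumulateGrad nodes hold a reference to their leaf tensor in `variable`.
        if hasattr(node, "variable"):
            leaves.add(node.variable)
    return leaves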


Thanks a lot for the answer! I will try to make a PR in the coming days. I will tag you on it when it’s ready, so that we can keep discussing there.