How does autograd know which gradient to pass to backward?

I was tinkering around trying to understand how autograd works in the background. For functions that output a single tensor, it's easier to interpret how the backward pass goes, but it gets more complex (IMO) when there are multiple outputs. To get an idea of what's happening, I created a custom function that takes an input and outputs two values (the double and the triple of the input, respectively). According to the custom autograd Function rules, the number of grad_outputs passed to the backward() method should equal the number of outputs (obviously). However, what I noticed is that autograd somehow knew which gradient to pass to which argument when backward() was called on only one of the two outputs. To get a sense of what I mean, here's some code:

import torch
from torch.autograd.function import Function


class DoubleTriple(Function):

    @staticmethod
    def forward(ctx, a):
        ctx.save_for_backward(a)
        # Two outputs: the double and the triple of the input.
        return a * 2, a * 3

    @staticmethod
    def backward(ctx, grad1, grad2):
        # One incoming gradient per forward output.
        print(grad1, grad2)
        return grad1 * 2 + grad2 * 3


def dt(a):
    return DoubleTriple.apply(a)


def main():

    a = torch.tensor(3.0, requires_grad=True)
    b, c = dt(a)
    b.backward()


if __name__ == "__main__":
    main()

Output:

tensor(1.) tensor(0.)

My question is: how does autograd know that b's gradient is the first one, given that the grad_fn attributes of both b and c point to the same object? I tried looking more into the Node class, but there wasn't anything I could find that indicated which gradient (implicitly passed from .backward()) should go to which grad argument of my custom backward function; it seems to just know.
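
For reference, here is the quick check that shows they share the same node (reusing the dt helper from the code above):

b, c = dt(torch.tensor(3.0, requires_grad=True))
print(b.grad_fn is c.grad_fn)  # True: both outputs point to the same DoubleTripleBackward node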

In Python, if you write

args = [1, 2, 3]

def function(a, b, c):
    print(a, b, c)

function(*args)

Output will be: 1 2 3

Perhaps the grads are passed here in a similar way?
Is this what you are asking?

Somewhat, sort of; I think this actually gives me some intuition. Maybe autograd creates a default for the grad_outputs and then just populates the indices with whichever gradients are provided. My question is mainly about how autograd knows that b is the first argument to my function, and the reason I ask, even though it seems obvious, is that nowhere in the source code could I find anything that shows b being marked as the first argument. Theoretically, it could just be storing the order of the outputs during the forward pass and then mapping them whenever the backward() method is called, but I can't prove that myself. I just want to understand where this is done and how it works, out of general curiosity.
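
For what it's worth, I can reproduce the implicit behaviour by passing a gradient for each output explicitly (just my own sanity check reusing dt from above, not something from the source):

a = torch.tensor(3.0, requires_grad=True)
b, c = dt(a)
# Explicitly provide one gradient per output: 1 for b, 0 for c.
torch.autograd.backward((b, c), (torch.tensor(1.0), torch.tensor(0.0)))
# Prints tensor(1.) tensor(0.), the same as calling b.backward() above.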

Update: it was staring me in the face. It appears to be an attribute of the Tensor class that gets assigned at some point during the forward pass.

torch/autograd/graph.py:

def get_gradient_edge(tensor):
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, it is equivalent to call
    ``g = autograd.grad(loss, input)`` and ``g = autograd.grad(loss, get_gradient_edge(input))``.
    """
    if not tensor.requires_grad:
        raise RuntimeError(
            "It is not possible to get the gradient edge for a Tensor that does not require gradients"
        )
    grad_fn = _get_grad_fn_or_grad_acc(tensor)

    # Note that output_nr default to 0 which is the right value
    # for the AccumulateGrad node.
    return GradientEdge(grad_fn, tensor.output_nr)
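
So the backward node is shared between the outputs, and each output just records its own index on that node. Checking the outputs of dt from earlier confirms it (my own quick check):

a = torch.tensor(3.0, requires_grad=True)
b, c = dt(a)
print(b.output_nr, c.output_nr)  # 0 1: the position of each output on the shared backward node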
