I was tinkering around trying to understand how autograd works in the background. For functions that output a single tensor, it's easier to interpret how the backward pass goes, but it gets more complex (IMO) when there are multiple outputs. To get an idea of what's happening, I created a custom function that takes an input and returns two values (the double and the triple of the input, respectively). According to the autograd custom Function rules, the number of grad_outputs passed to the backward() method should match the number of outputs (obviously). However, what I noticed is that autograd somehow knew which gradient to pass to which argument when backward() was called on only one of the two outputs. To get a sense of what I mean, here's some code:
import torch
from torch.autograd.function import Function

class DoubleTriple(Function):
    @staticmethod
    def forward(ctx, a):
        ctx.save_for_backward(a)
        # Two outputs: the double and the triple of the input.
        return a * 2, a * 3

    @staticmethod
    def backward(ctx, grad1, grad2):
        # grad1 and grad2 are the incoming gradients for the
        # two forward outputs, in order.
        print(grad1, grad2)
        return grad1 * 2 + grad2 * 3

def dt(a):
    return DoubleTriple.apply(a)

def main():
    a = torch.tensor(3.0, requires_grad=True)
    b, c = dt(a)
    b.backward()

if __name__ == "__main__":
    main()
Output:
tensor(1.) tensor(0.)
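For comparison, calling backward() on c instead flips which slot the gradient lands in. This is the same DoubleTriple as above; the only thing I changed is which output starts the backward pass:

a = torch.tensor(3.0, requires_grad=True)
b, c = dt(a)
c.backward()
# prints: tensor(0.) tensor(1.)
# i.e. the incoming gradient goes to grad2 this time, and
# a.grad ends up as 3.0 (the derivative of the triple output)

So autograd clearly routes the gradient to the right positional argument no matter which output I call backward() on.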
My question is: how does autograd know that b corresponds to the first gradient, given that the grad_fn attributes of both b and c point to the same object? I tried looking into the Node class, but there wasn't anything I could find indicating which gradient (implicitly passed from .backward()) should go to which grad argument of my custom backward function; it seems to just know.
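For what it's worth, inspecting the output tensors directly does suggest they share one Node but each carry their own index. Here's a quick check; note that output_nr isn't something I've found documented, so reading it as "this tensor's position among its grad_fn's outputs" is just my guess:

a = torch.tensor(3.0, requires_grad=True)
b, c = dt(a)
print(b.grad_fn is c.grad_fn)    # True: one shared Node
print(b.output_nr, c.output_nr)  # 0 1 (my guess: per-output index stored on the tensor)

If that reading is right, it would explain how each incoming gradient gets matched to the right positional grad argument, but I haven't been able to confirm it from the docs.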