Why does autograd.backward go one edge further than `inputs`?

Consider the following example (REPL pasteable):

import torch

class LoggingFn(torch.autograd.Function):
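    # Doubles its input and logs each time its backward is executed.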
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2
    @staticmethod
    def backward(ctx, grad_output):
        print("Backward called on LoggingFn")
        return grad_output * 2

a = torch.tensor(1.0, requires_grad=True)
b = LoggingFn.apply(a)
c = LoggingFn.apply(b)
loss = LoggingFn.apply(c)

loss.backward(inputs=[c])

print("a.grad =", a.grad)
print("b.grad =", b.grad)
print("c.grad =", c.grad)

This prints:

>>> loss.backward(inputs=[c])
Backward called on LoggingFn
Backward called on LoggingFn
>>> print("a.grad =", a.grad)
a.grad = None
>>> print("b.grad =", b.grad)
b.grad = None
>>> print("c.grad =", c.grad)
c.grad = tensor(2.)

LoggingFn.backward is called twice, meaning autograd backpropagates all the way back to b. It was correct not to go all the way back to a, but why didn't it stop at c? Is this incorrect behavior, or is there a reason to go one edge further than needed?


Here’s another weird example:

>>> a = torch.tensor(1.0, requires_grad=True)
>>> b = LoggingFn.apply(a)
>>> c = LoggingFn.apply(b)
>>> loss = LoggingFn.apply(c)
>>> loss.backward(inputs=[c])
Backward called on LoggingFn
Backward called on LoggingFn
>>> loss.backward(inputs=[c])
Backward called on LoggingFn
Backward called on LoggingFn
>>> loss.backward(inputs=[c])
Backward called on LoggingFn
Backward called on LoggingFn
>>> c.grad
tensor(6.)
>>> c.backward(gradient=c.grad)
Backward called on LoggingFn
Backward called on LoggingFn
>>> a.grad
tensor(24.)
>>> c.grad
tensor(12.)

We backprop from loss to c three times, and the gradients are correctly accumulated in c.grad and then backpropagated to a. But for some reason this also changed the gradient stored in c.

I guess this one makes sense: the argument gradient=c.grad was accumulated into the already stored c.grad.
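Working through the numbers: each node doubles its input, so dloss/dc = 2 and each loss.backward(inputs=[c]) adds 2 to c.grad, giving 6 after three calls. c.backward(gradient=c.grad) then pushes that 6 through the two remaining doubling nodes, so a.grad = 6 * 2 * 2 = 24, while the gradient argument of 6 is accumulated on top of the stored c.grad, giving 12.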

Indeed, we shouldn’t compute the extra backward. It’s just the current state of the implementation, kept for simplicity (see the note in the doc). Feel free to file an issue if you have a perf-sensitive use case and we can probably try to fix it.


Hi! As an additional note, this is only a problem with the backward function. If you care about performance and calling one extra node is too costly, you can also call torch.autograd.grad instead of torch.autograd.backward or tensor.backward, and the extra node will not be called.

The .grad fields will have to be accumulated manually, though.
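
For example, reusing the LoggingFn chain from the first post, a minimal sketch (the name grad_c is just illustrative, and the "printed once" expectation follows from the note above that the extra node is not called):

a = torch.tensor(1.0, requires_grad=True)
b = LoggingFn.apply(a)
c = LoggingFn.apply(b)
loss = LoggingFn.apply(c)

# torch.autograd.grad returns the gradients as a tuple instead of writing
# them into .grad fields; only the node that produced loss needs to run,
# so "Backward called on LoggingFn" should be printed once.
(grad_c,) = torch.autograd.grad(loss, inputs=[c])

# Accumulate manually to mimic loss.backward(inputs=[c]). Reading .grad on a
# non-leaf tensor may emit a warning unless c.retain_grad() was called first.
c.grad = grad_c if c.grad is None else c.grad + grad_c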