Consider the following example (REPL-pasteable):

import torch

# Identity-times-two op that logs every call to its backward.
class LoggingFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2
    @staticmethod
    def backward(ctx, grad_output):
        print("Backward called on LoggingFn")
        return grad_output * 2

a = torch.tensor(1.0, requires_grad=True)
b = LoggingFn.apply(a)
c = LoggingFn.apply(b)
loss = LoggingFn.apply(c)

loss.backward(inputs=[c])
print("a.grad =", a.grad)
print("b.grad =", b.grad)
print("c.grad =", c.grad)
This prints:
>>> loss.backward(inputs=[c])
Backward called on LoggingFn
Backward called on LoggingFn
>>> print("a.grad =", a.grad)
a.grad = None
>>> print("b.grad =", b.grad)
b.grad = None
>>> print("c.grad =", c.grad)
c.grad = tensor(2.)
LoggingFn.backward is called twice, meaning it backpropagates back to b. It was correct not to go all the way back to a, but why didn't it stop at c? Is this incorrect behavior, or is there a reason to go one edge further than needed?
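For reference, the chain structure is visible directly from the grad_fn objects (a quick sketch; the exact reprs may differ by PyTorch version):

print(loss.grad_fn)
# something like <torch.autograd.function.LoggingFnBackward object at 0x...>
print(loss.grad_fn.next_functions)
# ((<torch.autograd.function.LoggingFnBackward object at 0x...>, 0),)  <- the node that produced c
# The graph is a plain chain a -> b -> c -> loss, so filling c.grad should only
# require running the loss node's backward once, yet two log lines appear.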
Here’s another weird example:
>>> a = torch.tensor(1.0, requires_grad=True)
>>> b = LoggingFn.apply(a)
>>> c = LoggingFn.apply(b)
>>> loss = LoggingFn.apply(c)
>>> loss.backward(inputs=[c])
Backward called on LoggingFn
Backward called on LoggingFn
>>> loss.backward(inputs=[c])
Backward called on LoggingFn
Backward called on LoggingFn
>>> loss.backward(inputs=[c])
Backward called on LoggingFn
Backward called on LoggingFn
>>> c.grad
tensor(6.)
>>> c.backward(gradient=c.grad)
Backward called on LoggingFn
Backward called on LoggingFn
>>> a.grad
tensor(24.)
>>> c.grad
tensor(12.)
We backprop from loss to c three times, and the gradients are correctly accumulated in c.grad and then backpropagated to a. But for some reason this also changed the gradient saved in c.

I guess this one makes sense: the argument gradient=c.grad was accumulated into the already saved c.grad.
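To make the numbers concrete (a rough accounting, assuming each LoggingFn.backward simply doubles the incoming gradient, as defined above):

# each loss.backward(inputs=[c]) runs the loss -> c node and adds 2 to c.grad
c_grad = 2.0 + 2.0 + 2.0         # 6.0 after the three calls
# c.backward(gradient=c.grad) then backprops that 6 through c -> b -> a
a_grad = c_grad * 2 * 2          # 24.0
# and the gradient= argument is accumulated into the already stored c.grad
c_grad_after = c_grad + c_grad   # 12.0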
Indeed, we shouldn't compute the extra backward call; it's just the current state of things, kept that way for simplicity of implementation (see the note in the doc). Feel free to file an issue if you have a perf-sensitive use case and we can probably try to fix it.
Hi! As an additional note, this is only a problem with the backward function. If you care about performance and calling one extra node is too costly, you can also call torch.autograd.grad instead of torch.autograd.backward or tensor.backward, and the extra node will not be called. The .grad fields will have to be accumulated manually, though.
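For example, something along these lines (a minimal sketch reusing LoggingFn from above; the manual assignment mimics what inputs=[c] would otherwise have done to the .grad field):

a = torch.tensor(1.0, requires_grad=True)
b = LoggingFn.apply(a)
c = LoggingFn.apply(b)
loss = LoggingFn.apply(c)
(grad_c,) = torch.autograd.grad(loss, c)   # should print "Backward called on LoggingFn" only once
print(grad_c)                              # tensor(2.)
# If you still want the value stored on the tensor, accumulate it yourself
# (add to the existing c.grad instead of overwriting if it is already populated):
c.grad = grad_c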