I’m reading the Extending torch.autograd doc and the Double Backward with Custom Functions tutorial. I think I have a grasp of the idea and why it matters for higher-order gradient computations. That said, I have a feeling that once_differentiable provides very limited help in terms of debugging and error reporting. Consider this toy x ** 2 example (a simplified version of the “what not to do” example from the tutorial):
import torch

class BadSquare(torch.autograd.Function):
    # This is an example of what NOT to do!
    @staticmethod
    def forward(ctx, x):
        x2 = x * 2
        ctx.save_for_backward(x2)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        x2, = ctx.saved_tensors
        return grad_out * x2

x = torch.rand(2, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(BadSquare.apply, x)      # first-order check passes
torch.autograd.gradgradcheck(BadSquare.apply, x)  # fails: grad(grad(x)) Jacobian mismatch
Instead of saving x on ctx and using it in the backward call, it creates an intermediate tensor x2, which is detached from the computational graph (forward runs in a no-grad context), and uses that in backward (bad!). This makes x.grad independent of x, hence the grad(grad(x)) check fails.
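For comparison, here is a minimal sketch of a version that does support double backward, along the lines of what the tutorial recommends (save the input itself and build the gradient from tracked operations inside backward; GoodSquare is just my throwaway name for it):

import torch

class GoodSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save the input itself, not a derived tensor.
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        # grad_out * 2 * x is built from tensors that are tracked when backward
        # runs with create_graph=True, so it is itself differentiable w.r.t. x.
        return grad_out * 2 * x

x = torch.rand(2, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(GoodSquare.apply, x)
torch.autograd.gradgradcheck(GoodSquare.apply, x)  # passes now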
Now, since this BadSquare function doesn’t support double backward, I mark the backward as once_differentiable:
@staticmethod
@torch.autograd.function.once_differentiable  # add this
def backward(ctx, grad_out):
    x2, = ctx.saved_tensors
    return grad_out * x2
However, the only thing this decorator does is check whether grad_out requires grad in this call (it will error if it does). The thing is, for almost all such “bad examples”, the key limitation is that some intermediate results are hard (or impossible) to associate with the input x through the standard computational graph (e.g. we may need some third-party approach to compute the grad). That has nothing to do with grad_out, and knowing whether grad_out needs grad or not barely helps. (Side question: is there even a case where grad_out may require grad?)
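As a rough way to poke at that side question, here is a small probe I tried (ProbeSquare and the (y * x).sum() loss are just my own throwaway construction, not something from the docs); it prints whether the incoming grad_out is itself attached to a graph:

import torch

class ProbeSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        # Inspect whether the incoming gradient is part of a graph.
        print("grad_out.requires_grad =", grad_out.requires_grad)
        x, = ctx.saved_tensors
        return grad_out * 2 * x

x = torch.rand(2, requires_grad=True, dtype=torch.double)

# Plain first-order backward: the incoming gradient is a constant ones tensor.
ProbeSquare.apply(x).sum().backward()

# create_graph=True with an upstream gradient that depends on x:
y = ProbeSquare.apply(x)
loss = (y * x).sum()                 # d(loss)/dy = x, which requires grad
torch.autograd.grad(loss, x, create_graph=True)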
If I compute the Hessian:
def bad_sum_squared(x: torch.Tensor):
    return BadSquare.apply(x).sum()

x = torch.rand(2, requires_grad=True, dtype=torch.double)
h = torch.autograd.functional.hessian(bad_sum_squared, x)
print(h)  # all zeros
It gives all zeros both with and without the once_differentiable decorator. No error is raised with the decorator; it just silently gives a wrong result. So to me this behaves exactly like a backward that doesn’t support higher-order derivatives: the computed gradients are treated as “const” w.r.t. the inputs (as they would be for linear/affine functions). Adding the decorator does nothing more than emphasize that “we gave up”; it doesn’t produce errors as intended.
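A quick way to see what I mean (this builds on the BadSquare and bad_sum_squared definitions above; the check itself is just my own):

x = torch.rand(2, requires_grad=True, dtype=torch.double)
g, = torch.autograd.grad(bad_sum_squared(x), x, create_graph=True)
# Even with create_graph=True, the first-order gradient carries no graph back
# to x, so anything differentiated through it comes out as zero.
print(g.requires_grad, g.grad_fn)  # False None, with or without the decorator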
What am I missing here?