In what situation does the decorator once_differentiable help?

I’m reading the Extending torch.autograd doc and the Double Backward with Custom Functions tutorial. I think I have a grasp of the idea and why it matters for higher-order gradient computations. That said, I have a feeling that once_differentiable provides very limited help in terms of debugging and error reporting. Consider this toy x ** 2 example (a simplified version of the not-to-do example from the tutorial):

import torch

class BadSquare(torch.autograd.Function):
    # This is an example of what NOT to do!
    @staticmethod
    def forward(ctx, x):
        x2 = x * 2  # computed inside forward's no-grad context: x2 has no graph back to x
        ctx.save_for_backward(x2)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        x2, = ctx.saved_tensors
        return grad_out * x2  # x2 is effectively a constant here, so this cannot be differentiated w.r.t. x

x = torch.rand(2, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(BadSquare.apply, x)  # first order check passes
torch.autograd.gradgradcheck(BadSquare.apply, x)  # fails, grad(grad(x)) Jacobian mismatch

Instead of saving x with ctx and using it in the backward call, the function creates an intermediate tensor x2 that is detached from the computational graph (it is computed in forward’s no-grad context) and uses that in backward (bad!), making the computed gradient a constant w.r.t. x. Hence the grad(grad(x)) check fails.
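(For comparison, a version that does support double backward, roughly along the lines of what the tutorial recommends, saves x itself and recomputes 2 * x inside backward, so the multiplication by x is recorded when backward runs with create_graph=True. GoodSquare is just my own name for this sketch:)

class GoodSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # save the input itself, not a detached intermediate
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        # 2 * x is rebuilt here, so under create_graph=True this product is part
        # of the graph and double backward sees the correct 2 * grad_out
        return grad_out * 2 * x

torch.autograd.gradcheck(GoodSquare.apply, x)      # passes
torch.autograd.gradgradcheck(GoodSquare.apply, x)  # passes as well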

Now, since this BadSquare function doesn’t support double backward, I mark the backward as once_differentiable:

    @staticmethod
    @torch.autograd.function.once_differentiable  # add this
    def backward(ctx, grad_out):
        x2, = ctx.saved_tensors
        return grad_out * x2

However, as far as I can tell, the only thing this decorator does is check whether any incoming grad_out requires grad in this call; if so, it arranges for an error when someone tries to differentiate through the returned gradients. The thing is, for almost all such “bad examples”, the key limitation is that we have intermediate results that are hard (or impossible) to associate with the input x via the standard computational graph (e.g. we may need some third-party approach to compute the grad). That limitation has nothing to do with grad_out, so knowing whether grad_out requires grad barely helps. (Side question: is there even a case where grad_out may require grad?)
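(For reference, my reading of torch/autograd/function.py is that the wrapper does roughly the following; this is a paraphrased sketch, not the actual source:)

import torch

def once_differentiable_sketch(backward_fn):
    def wrapper(ctx, *grad_outputs):
        with torch.no_grad():
            # the user-defined backward runs without building any graph
            results = backward_fn(ctx, *grad_outputs)
        fires = torch.is_grad_enabled() and any(
            isinstance(g, torch.Tensor) and g.requires_grad for g in grad_outputs
        )
        if not fires:
            # nothing is guarded: the returned grads simply carry no graph
            return results
        # otherwise the real decorator attaches a node to the results that raises
        # an error ("trying to differentiate twice a function that was marked with
        # @once_differentiable", or similar) when they are backpropagated through
        ...
    return wrapper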

If I compute the Hessian:

def bad_sum_squared(x: torch.Tensor):
    return BadSquare.apply(x).sum()

x = torch.rand(2, requires_grad=True, dtype=torch.double)
h = torch.autograd.functional.hessian(bad_sum_squared, x)
print(h)  # all zeros

It gives all zeros both with and without the once_differentiable decorator, and no error is raised when the decorator is present; it just gives a wrong result. So to me this already behaves like a backward that doesn’t support higher-order derivatives: the computed gradients are treated as constants w.r.t. the inputs (as for linear/affine functions). Adding the decorator does nothing more than emphasize that “we gave up”; it doesn’t produce the intended error.
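(For what it’s worth, going through torch.autograd.grad directly instead of the functional API gives the same generic failure in both cases, not the dedicated once_differentiable message; I presume the functional hessian allows unused inputs internally and fills in zeros, which is why it stays silent:)

y = BadSquare.apply(x).sum()
g, = torch.autograd.grad(y, x, create_graph=True)
# g has no grad_fn here, with or without the decorator, because grad_out never
# required grad inside backward; the next line raises the generic
# "does not require grad and does not have a grad_fn" RuntimeError either way
torch.autograd.grad(g.sum(), x)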

What am I missing here?


Another post mentioning once_differentiable:

We say that it prevents higher-order derivative computation, but as I mentioned above, I doubt it makes any difference if we simply remove the decorator from the example code in that post: in either case the higher-order derivatives are still “computable”, they are just all zeros (again, with or without the decorator).

If you had a function for which higher derivatives exist, but you silently return incorrect results (e.g. zero), then once_differentiable could be a good way to explicitly raise an error when someone tries to access higher-order derivatives.

The thing is, if someone tries to access higher-order derivatives via double backward, then for the decorator to matter there must be a grad_output in the backward parameter list that requires grad. However, that’s often not the case when someone encapsulates fancy/complicated logic in the forward/backward implementation, where the parameters of the backward call are anyway irrelevant to the higher-order derivative computation. This makes the once_differentiable decorator pretty much useless. E.g. I gave an example in the main post where it doesn’t stop a wrong Hessian computation. Another example is the post above, where I also believe it does nothing, because the only access exposed to the outside is grad_output, which doesn’t require grad in those higher-order computations anyway.
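(Regarding the side question above, the only way I can see grad_output ending up requiring grad is when the caller feeds a differentiable vector into the vector-Jacobian product themselves, which the functional hessian/jacobian helpers never seem to do. A hypothetical sketch, reusing x and BadSquare from the main post:)

v = torch.rand(2, dtype=torch.double, requires_grad=True)
y = BadSquare.apply(x)
g, = torch.autograd.grad(y, x, grad_outputs=v, create_graph=True)
# the grad_out passed to BadSquare.backward comes from v and requires grad, so
# this is the one case where once_differentiable would arrange for its
# "differentiate twice" error on a further backward through g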