Is there a gradgradgrad check? (TripleBackward check)

Hi All,

I’m trying to validate a manually derived derivative up to its third derivative. I’ve been using torch.autograd.gradcheck and torch.autograd.gradgradcheck to check the first and second derivatives respectively (torch.autograd.gradcheck — PyTorch 2.0 documentation). However, there doesn’t seem to be an equivalent for the third derivative.
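
For reference, this is roughly how I’m calling the two existing checks, with a toy function standing in for my custom op (gradcheck wants double-precision inputs with requires_grad=True):

import torch
from torch.autograd import gradcheck, gradgradcheck

# Toy stand-in for the custom op, just to show the calling convention.
def f(x, y):
    return (x * y).sin()

x = torch.randn(3, dtype=torch.double, requires_grad=True)
y = torch.randn(3, dtype=torch.double, requires_grad=True)

print(gradcheck(f, (x, y)))      # checks the first derivative
print(gradgradcheck(f, (x, y)))  # checks the second derivative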

Should I simply compute the third-order derivative with both methods and compare the results with torch.allclose?

You could also have a function that computes the second derivative and pass that to gradcheck.
Or have a function that computes the first derivative and pass that to gradgradcheck.
gradgradcheck itself is just a function that wraps your function to compute the first derivative and then calls into gradcheck underneath.
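
Something like the following rough sketch illustrates both options, with a toy function f standing in for the custom op:

import torch
from torch.autograd import gradcheck, gradgradcheck

def f(x, y):
    return (x * y).sin()  # stand-in for the custom op

x = torch.randn(3, dtype=torch.double, requires_grad=True)
y = torch.randn(3, dtype=torch.double, requires_grad=True)

def first_derivative(x, y):
    # Differentiate a scalar reduction of f w.r.t. both inputs, keeping the
    # graph so the result can itself be differentiated further.
    return torch.autograd.grad(f(x, y).sum(), (x, y), create_graph=True)

def second_derivative(x, y):
    gx, gy = first_derivative(x, y)
    return torch.autograd.grad(gx.sum() + gy.sum(), (x, y), create_graph=True)

# Option 1: first derivative + gradgradcheck exercises up to the third derivative.
print(gradgradcheck(first_derivative, (x, y)))

# Option 2: second derivative + gradcheck also exercises the third derivative.
print(gradcheck(second_derivative, (x, y)))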


TL;DR - Is the pseudocode I’ve shown below the correct way to compute a triple gradcheck? I’ve tried this and it reports that my analytical backward is wrong, even though when I compute the triple jacrev manually and compare it against a PyTorch-primitives version of my custom function, the results are exactly the same. So, I think my jacrev_func function may be wrong.

Hi @soulitzer,

Sorry to reopen this topic, but would it be as simple as applying torch.func.jacrev to my custom function and passing that new function to torch.autograd.gradgradcheck? For example:

import torch

x = torch.randn(1, dtype=torch.double, requires_grad=True)
y = torch.randn(1, dtype=torch.double, requires_grad=True)

def func(x, y):
    return myCustomFunction.apply(x, y)  # some custom function which has 3 custom derivatives

def jacrev_func(x, y):
    # first derivatives w.r.t. both inputs via torch.func.jacrev
    jacrev_x, jacrev_y = torch.func.jacrev(func, argnums=(0, 1))(x, y)
    return jacrev_x, jacrev_y

triple_grad_check = torch.autograd.gradgradcheck(jacrev_func, inputs=(x, y))

I only ask because this is what I’ve done, and torch.autograd.gradgradcheck states my derivatives are wrong. I have checked my custom derivatives by calculating a triple jacrev with my custom function (and comparing it to a PyTorch-primitives version), and comparing the two results with torch.allclose returns True.
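
Roughly what that comparison looks like (a sketch: custom stands in for a wrapper around myCustomFunction.apply and reference for the primitives version; here both are the same toy function just to show the pattern):

import torch
from torch.func import jacrev

def custom(x, y):
    return (x * y).sin()  # in my case: a wrapper around myCustomFunction.apply

def reference(x, y):
    return (x * y).sin()  # same computation written with PyTorch primitives

x = torch.randn(1, dtype=torch.double)
y = torch.randn(1, dtype=torch.double)

# Three nested jacrevs w.r.t. x give the third derivative.
d3_custom = jacrev(jacrev(jacrev(custom)))(x, y)
d3_reference = jacrev(jacrev(jacrev(reference)))(x, y)

print(torch.allclose(d3_custom, d3_reference))  # returns True in my case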

So, there’s some confusion as to why computing the triple jacrev directly and comparing via torch.allclose returns True, yet torch.autograd.gradgradcheck returns False.

Is the pseudocode I’ve shown above the right way to check the triple-backward?