Hi All,
I’m trying to validate a manually derived derivative up to its third order. I’ve been using torch.autograd.gradcheck and torch.autograd.gradgradcheck to check the first and second derivatives respectively (torch.autograd.gradcheck — PyTorch 2.0 documentation). However, there doesn’t seem to be an equivalent for the third derivative.
Should I simply compute the third-order derivative with both methods and compare the results with torch.allclose?
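For reference, that allclose approach might look like the minimal sketch below. It uses x**4 as a stand-in for the function being checked (an assumption, not the poster's actual function), so the hand-derived third derivative, 24x, is known exactly:

```python
import torch

def f(x):
    return x ** 4  # stand-in for the function under test

# gradcheck-style comparisons should use double precision
x = torch.randn(3, dtype=torch.double, requires_grad=True)

# Chain autograd.grad three times; create_graph=True keeps each
# intermediate derivative differentiable so the next grad call works.
d1 = torch.autograd.grad(f(x).sum(), x, create_graph=True)[0]  # 4x^3
d2 = torch.autograd.grad(d1.sum(), x, create_graph=True)[0]    # 12x^2
d3 = torch.autograd.grad(d2.sum(), x)[0]                       # 24x

manual_d3 = 24 * x.detach()  # analytically derived third derivative
assert torch.allclose(d3, manual_d3)
```

This only spot-checks at one random input, whereas gradcheck-style tools also perturb the inputs numerically.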
You could also pass a function that computes the second derivative to gradcheck, or a function that computes the first derivative to gradgradcheck. gradgradcheck itself just wraps your function to compute the first derivative and then calls into gradcheck underneath.
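As a concrete sketch of the second suggestion, assuming a smooth stand-in function (x**4 summed, not the poster's custom op): passing a function that returns the first derivative to gradgradcheck effectively exercises the second and third derivatives.

```python
import torch

def f(x):
    return (x ** 4).sum()  # smooth scalar stand-in function

def first_derivative(x):
    # create_graph=True keeps the gradient differentiable, which
    # gradgradcheck needs in order to double-backward through it
    (g,) = torch.autograd.grad(f(x), x, create_graph=True)
    return g

# gradcheck-family tools expect double precision and requires_grad inputs
x = torch.randn(3, dtype=torch.double, requires_grad=True)

# Checking grad-of-grad of the first derivative = 2nd and 3rd derivatives of f
assert torch.autograd.gradgradcheck(first_derivative, (x,))
```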
TL;DR - Is the pseudocode I’ve shown below the correct way to compute a triple gradcheck? I’ve tried this and it reports that my analytical backward is wrong, even though when I compute the triple-jacrev manually and compare it against a PyTorch-primitives version of my custom function, the results are exactly the same. So I suspect my jacrev_func function may be wrong.
Hi @soulitzer,
Sorry to reopen this topic, but would it be as simple as applying torch.func.jacrev to my custom function and passing that new function to torch.autograd.gradgradcheck? For example:
import torch

# gradcheck-family tools expect double precision and requires_grad inputs
x = torch.randn(1, dtype=torch.double, requires_grad=True)
y = torch.randn(1, dtype=torch.double, requires_grad=True)

def func(x, y):
    # some custom autograd.Function which has 3 custom derivatives
    return myCustomFunction.apply(x, y)

def jacrev_func(x, y):
    jacrev_x, jacrev_y = torch.func.jacrev(func, argnums=(0, 1))(x, y)
    return jacrev_x, jacrev_y

triple_grad_check = torch.autograd.gradgradcheck(jacrev_func, inputs=(x, y))
I only ask because this is what I’ve done, and torch.autograd.gradgradcheck states my derivatives are wrong. I have checked my custom derivatives by computing a triple-jacrev with my custom function, doing the same with a PyTorch-primitives version, and comparing the two results with torch.allclose, which returns True.
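The direct triple-jacrev comparison described above might be sketched as follows. Here reference_f is a hypothetical PyTorch-primitives stand-in, and custom_f is aliased to it purely so the sketch runs; in the real check custom_f would be the custom autograd.Function:

```python
import torch

def reference_f(x, y):
    # hypothetical stand-in built from PyTorch primitives
    return torch.sin(x) * torch.cos(y)

custom_f = reference_f  # placeholder for the custom autograd.Function

x = torch.randn(1, dtype=torch.double)
y = torch.randn(1, dtype=torch.double)

def triple_jacrev(f):
    # d^3 f / dx^3 via three nested jacrev calls (argnums=0 for brevity)
    return torch.func.jacrev(torch.func.jacrev(torch.func.jacrev(f)))(x, y)

assert torch.allclose(triple_jacrev(custom_f), triple_jacrev(reference_f))
```

Note this evaluates both functions at a single random point, while gradgradcheck also perturbs the inputs numerically, which can expose errors that a pointwise allclose misses.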
So, there’s some confusion as to why computing the triple-jacrev directly and comparing via torch.allclose returns True, yet torch.autograd.gradgradcheck returns False.
Is the pseudocode I’ve shown above the right way to check the triple-backward?