TL;DR: How do you implement a custom `jvp` method for a `torch.autograd.Function`?
Hi All,
I’ve been trying to write a forward-over-reverse function to efficiently compute the Laplacian of a given function, expanding on what was discussed here.
I understand that forward-over-reverse isn’t vectorized, so I won’t get the ideal speed-up, but I wanted to test it out anyway. My current version of the function is:
```python
import torch
import torch.autograd.forward_ad as fwAD

def laplacian_from_log_forward_reverse(func, xs):
    # xs must have requires_grad=True for the torch.autograd.grad call below
    with fwAD.dual_level():
        jacobian = torch.zeros(*xs.shape, device=xs.device, dtype=xs.dtype)
        laplacian = torch.zeros(*xs.shape, device=xs.device, dtype=xs.dtype)
        for i in range(xs.shape[-1]):
            tangent = torch.zeros_like(xs)
            tangent[:, i] = 1  # seed the i-th input direction for forward-mode AD
            dual_in = fwAD.make_dual(xs, tangent)
            dual_out = func(dual_in)
            primal_out, tangent_out = fwAD.unpack_dual(dual_out)
            jacobian[:, i] = tangent_out
            out = torch.autograd.grad(tangent_out, xs, torch.ones_like(tangent_out),
                                      retain_graph=True, create_graph=True)[0]
            laplacian[:, i] = out[:, i]
        return torch.sum(laplacian + jacobian.pow(2), dim=-1)
```
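For reference, here is a self-contained sketch of the same forward-over-reverse loop with a sanity check against an analytic result (the function name `laplacian_fwd_rev` is just for this example). For `f(x) = Σ xᵢ²` the returned quantity `Σᵢ (∂²f/∂xᵢ² + (∂f/∂xᵢ)²)` should equal `2d + 4‖x‖²`:

```python
import torch
import torch.autograd.forward_ad as fwAD

def laplacian_fwd_rev(func, xs):
    # Standalone sketch: xs must require grad so the reverse-mode grad of the
    # forward-mode tangent can be taken with respect to the inputs.
    xs = xs.detach().requires_grad_(True)
    jac = torch.zeros_like(xs)
    lap = torch.zeros_like(xs)
    with fwAD.dual_level():
        for i in range(xs.shape[-1]):
            tangent = torch.zeros_like(xs)
            tangent[:, i] = 1.0  # seed the i-th input direction
            dual_out = func(fwAD.make_dual(xs, tangent))
            _, tangent_out = fwAD.unpack_dual(dual_out)  # tangent_out = df/dx_i
            jac[:, i] = tangent_out
            # Reverse-mode grad of the tangent gives a row of the Hessian
            g = torch.autograd.grad(tangent_out, xs,
                                    torch.ones_like(tangent_out),
                                    retain_graph=True)[0]
            lap[:, i] = g[:, i]  # diagonal entry d^2 f / dx_i^2
    return torch.sum(lap + jac.pow(2), dim=-1)

x = torch.randn(5, 3)
out = laplacian_fwd_rev(lambda x: x.pow(2).sum(dim=-1), x)
expected = 2 * x.shape[-1] + 4 * x.pow(2).sum(dim=-1)
```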
When I pass in my function (which is an `nn.Module`), it contains a custom `torch.autograd.Function` within it and the forward-mode pass fails. I know the Laplacian function itself works, as testing it on
```python
def func(x):
    return x.pow(2).sum(dim=-1)
```
works completely fine and matches reverse-over-reverse methods, albeit more slowly. But how do you define a custom `jvp` so that forward-mode AD can differentiate through the `torch.autograd.Function`?
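In case it helps frame the question, my understanding (based on the extending-autograd docs, and assuming a recent PyTorch, ≥ 1.11 I believe) is that a custom `Function` advertises forward-mode support by implementing a static `jvp` method; `Exp` below is a toy example, not my actual function:

```python
import torch
import torch.autograd.forward_ad as fwAD

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        result = torch.exp(x)
        ctx.save_for_backward(result)  # for reverse mode
        ctx.result = result            # stash for the jvp below
        return result

    @staticmethod
    def backward(ctx, grad_out):
        result, = ctx.saved_tensors
        return grad_out * result

    @staticmethod
    def jvp(ctx, x_tangent):
        # Forward-mode rule: d(exp(x)) = exp(x) * dx
        result = ctx.result
        del ctx.result
        return x_tangent * result

x = torch.randn(3)
with fwAD.dual_level():
    dual = fwAD.make_dual(x, torch.ones_like(x))
    primal, tangent = fwAD.unpack_dual(Exp.apply(dual))
```

With a unit tangent, `tangent` should come out equal to `exp(x)`, matching the analytic derivative.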
Any help would be greatly appreciated!