TL;DR - How do you implement a custom jvp method for a torch.autograd.Function so forward-mode AD can differentiate it?
I’ve been trying to make a forward-over-reverse function to efficiently compute the Laplacian of a given function, and I’ve been expanding upon what was discussed here.
I understand that forward-over-reverse isn’t vectorized, so I won’t get the ideal speed-up, but I just wanted to test it out anyway. My current version of the function:
```python
import torch
import torch.autograd.forward_ad as fwAD

def laplacian_from_log_forward_reverse(func, xs):
    # xs must have requires_grad=True so the reverse pass can run.
    with fwAD.dual_level():
        jacobian = torch.zeros(*xs.shape, device=xs.device, dtype=xs.dtype)
        laplacian = torch.zeros(*xs.shape, device=xs.device, dtype=xs.dtype)
        for i in range(xs.shape[-1]):
            tangent = torch.zeros_like(xs)
            tangent[:, i] = 1  # mark the index for forward AD
            dual_in = fwAD.make_dual(xs, tangent)
            dual_out = func(dual_in)
            primal_out, tangent_out = fwAD.unpack_dual(dual_out)
            jacobian[:, i] = tangent_out
            # torch.autograd.grad returns a tuple, so index into it.
            out = torch.autograd.grad(
                tangent_out, xs, torch.ones_like(tangent_out),
                retain_graph=True, create_graph=True,
            )[0]
            laplacian[:, i] = out[:, i]
        return torch.sum(laplacian + jacobian.pow(2), dim=-1)
```
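For reference, the reverse-over-reverse baseline I compare against looks roughly like this (a sketch, not my exact code; it assumes `func` maps a batch of shape `(B, D)` to per-sample scalars of shape `(B,)`, and computes the same quantity, i.e. the Laplacian of `exp(func)` divided by `exp(func)`):

```python
import torch

def laplacian_reverse_reverse(func, xs):
    # Reverse-over-reverse sketch: one backward pass for the gradient,
    # then one backward pass per input dimension for the Hessian diagonal.
    xs = xs.detach().requires_grad_(True)
    out = func(xs)
    (jac,) = torch.autograd.grad(out, xs, torch.ones_like(out), create_graph=True)
    lap = torch.zeros_like(xs)
    for i in range(xs.shape[-1]):
        # Gradient of the i-th Jacobian column; its i-th column is the
        # i-th diagonal entry of the Hessian for each sample.
        (col,) = torch.autograd.grad(jac[:, i].sum(), xs, retain_graph=True)
        lap[:, i] = col[:, i]
    return (lap + jac.pow(2)).sum(dim=-1)

xs = torch.randn(4, 3)
result = laplacian_reverse_reverse(lambda x: x.pow(2).sum(dim=-1), xs)
# For f(x) = sum(x_i^2): Laplacian term is 2*D and |grad|^2 is 4*|x|^2.
```
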
When I pass in my function (which is an nn.Module), it has a custom torch.autograd.Function inside it and the call fails. I know the approach itself works, as testing it on
```python
def func(x):
    return x.pow(2).sum(dim=-1)
```
works completely fine and matches reverse-over-reverse methods, albeit more slowly. But how do you define a custom jvp so that forward-mode AD can differentiate through the Function?
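For concreteness, here is the kind of pattern I mean: a toy custom Function for f(x) = x² with a jvp defined, based on my reading of the extending-autograd docs (the `Square` class is illustrative, not my real Function):

```python
import torch
import torch.autograd.forward_ad as fwAD

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # tensors needed by backward()
        ctx.save_for_forward(x)   # tensors needed by jvp()
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

    @staticmethod
    def jvp(ctx, x_t):
        # Jacobian-vector product of f(x) = x*x applied to the tangent x_t.
        (x,) = ctx.saved_tensors
        return 2 * x * x_t

with fwAD.dual_level():
    x = torch.randn(3)
    dual = fwAD.make_dual(x, torch.ones_like(x))
    primal, tangent = fwAD.unpack_dual(Square.apply(dual))
    # tangent should equal 2 * x
```
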
Any help would be greatly appreciated!