I’m going to define a function with two arguments. I need to implement the derivative w.r.t. one of them myself, but let the autograd engine calculate the other derivative. I’ve tried `grad`, `vjp`, `jacfwd`, and `jacrev` from the functorch package to leave the second gradient to autograd. They do work, but are too slow. Here is an example:
```python
import torch
from torch.autograd import Function
from functorch import jacfwd

def f(x, y):
    return x**2 * torch.sin(y)

class F(Function):
    @staticmethod
    def forward(ctx, x, y):
        output = f(x, y)  # x**2 * torch.sin(y)
        ctx.save_for_backward(x, y)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # Manual derivative w.r.t. y: d/dy [x^2 sin(y)] = x^2 cos(y)
        grad_y = grad_output * x**2 * torch.cos(y)
        # Derivative w.r.t. x left to functorch (this is the slow part)
        grad_x = grad_output * jacfwd(f, argnums=0)(x, y)
        return grad_x, grad_y
```
The above code is just a simple example: I defined the derivative w.r.t. y by hand, but let autograd handle the one w.r.t. x.
My actual function is longer and more costly to run, though. Is there an efficient way to do this?
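For concreteness, here is a minimal sketch of the kind of hybrid I have in mind, using plain `torch.autograd.grad` inside `backward` instead of functorch. The class name `MixedGradFn` is my own, and I don't know if re-running the forward under `enable_grad` like this is the intended pattern:

```python
import torch

def f(x, y):
    return x**2 * torch.sin(y)

class MixedGradFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return f(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # Manual derivative w.r.t. y: d/dy [x^2 sin(y)] = x^2 cos(y)
        grad_y = grad_output * x**2 * torch.cos(y)
        # Re-run only the forward under enable_grad so autograd
        # computes the vector-Jacobian product w.r.t. x alone.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            out = f(x_, y.detach())
            grad_x, = torch.autograd.grad(out, x_, grad_output)
        return grad_x, grad_y
```

This avoids materializing a full Jacobian the way `jacfwd` does, but it does evaluate `f` a second time during the backward pass.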