Custom computation of only some gradients


I’m going to define a function with two arguments. I need to implement the derivative w.r.t. one of them myself, but let the autograd engine calculate the other derivative. I’ve tried grad, vjp, jacfwd, and jacrev from the functorch package to leave the second gradient to autograd. They do work, but they are too slow. Here is an example:

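import torch
from torch.autograd import Function
from torch.func import jacfwd  # functorch's jacfwd, now shipped as torch.func

def f(x, y):
    return x**2 * torch.sin(y)

class F(Function):
    @staticmethod
    def forward(ctx, x, y):
        output = f(x, y)  # x**2 * torch.sin(y)
        ctx.save_for_backward(x, y)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # Hand-written derivative w.r.t. y, chained with the incoming gradient
        grad_y = grad_output * x**2 * torch.cos(y)
        # Derivative w.r.t. x left to autograd (scalar x; for tensors jacfwd returns a full Jacobian)
        grad_x = grad_output * jacfwd(f, argnums=0)(x, y)
        return grad_x, grad_y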

The above code is just a simple example: I define the derivative w.r.t. y by hand and have autograd compute the one w.r.t. x.
My actual code is longer and more costly to run, though, so the functorch call inside backward becomes the bottleneck. Is there a faster way to do this?
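
For reference, here is a minimal sketch of one possible way to avoid a nested transform call in backward. It is not a confirmed fix for the real code. The idea is to evaluate f with y detached, so ordinary autograd only tracks x, and then add a zero-valued custom Function that carries the hand-written gradient for y. The names ManualGradY and f_mixed are hypothetical, and the sketch assumes x and y have the same shape and that only first-order gradients are needed.

```python
import torch
from torch.autograd import Function


def f(x, y):
    return x**2 * torch.sin(y)


class ManualGradY(Function):
    """Hypothetical helper: adds zero in forward, supplies the manual dy in backward."""

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        # Zeros with the shape of the output (assumes x and y share a shape)
        return torch.zeros_like(x * y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # Hand-written derivative of f w.r.t. y, chained with grad_output
        grad_y = grad_output * x**2 * torch.cos(y)
        # No gradient for x from this node; x is handled by regular autograd
        return None, grad_y


def f_mixed(x, y):
    # y is detached, so autograd differentiates this term w.r.t. x only
    out = f(x, y.detach())
    # Adds zero to the value; x is detached, so only y receives a gradient here
    return out + ManualGradY.apply(x.detach(), y)


x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = torch.randn(5, dtype=torch.double, requires_grad=True)

# Gradients from the mixed version and from plain autograd, for comparison
gx, gy = torch.autograd.grad(f_mixed(x, y).sum(), (x, y))
gx_ref, gy_ref = torch.autograd.grad(f(x, y).sum(), (x, y))
```

With this layout the x gradient comes from the normal backward pass, so no jacfwd or vjp call is needed inside backward. Whether it is actually faster for a longer, more expensive function would have to be measured.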