Taking derivatives of a neural network wrt parameters after first being differentiated wrt state

I have a neural network, ANN(x), that is a function of x but also has parameters to optimize. I need to calculate its derivative with respect to x, and then I would like to use gradient descent to optimize dANN(x)/dx. This requires taking the derivative of dANN(x)/dx with respect to the parameters of ANN(x).
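In symbols, this is a sketch where the per-sample loss ℓ is a hypothetical placeholder for whatever is being minimised on the derivative:

$$
g_\theta(x) = \frac{\partial\,\mathrm{ANN}_\theta(x)}{\partial x}, \qquad
L(\theta) = \frac{1}{B}\sum_{i=1}^{B} \ell\big(g_\theta(x_i)\big), \qquad
\nabla_\theta L = \frac{1}{B}\sum_{i=1}^{B}
\left(\frac{\partial g_\theta(x_i)}{\partial \theta}\right)^{\!\top}
\nabla \ell\big(g_\theta(x_i)\big),
\quad
\frac{\partial g_\theta}{\partial \theta} = \frac{\partial^2 \mathrm{ANN}_\theta}{\partial \theta\,\partial x}.
$$

Reverse-mode autograd can evaluate this vector-Jacobian product without materialising the mixed second derivative, provided the graph that produced g is kept (in PyTorch, `create_graph=True`).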

I can do this with the autograd jacobian function, but that's really slow. I would like to use the faster grad or backward functions together with torch.autograd.Function, but I'm not sure how to write the backward function so that each parameter involved in dANN(x)/dx gets the correct gradient, especially since the input to the forward function is a list of parameters (i.e. ANN.parameters()), each of which should receive its own gradient.
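One way to get there without a custom torch.autograd.Function is double backpropagation: compute dANN(x)/dx with `torch.autograd.grad(..., create_graph=True)`, build a loss from it, and call `loss.backward()`, which routes a gradient into the `.grad` of each parameter the derivative depends on. A minimal sketch, assuming a hypothetical 8-input, 2-output MLP in place of ANN and a made-up squared-norm objective:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for ANN: maps an 8-dim state to 2 outputs.
model = nn.Sequential(nn.Linear(8, 32), nn.Tanh(), nn.Linear(32, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

batch = 16
x = torch.randn(batch, 8, requires_grad=True)  # state we differentiate against

out = model(x)[:, 1]  # one scalar output per sample, as in the question

# d(out)/dx for every sample. create_graph=True records this gradient
# computation so that it can itself be differentiated w.r.t. the parameters.
(dx,) = torch.autograd.grad(out, x, grad_outputs=torch.ones_like(out), create_graph=True)

# Hypothetical objective on the first two components of d(out)/dx.
loss = dx[:, :2].pow(2).mean()

optimizer.zero_grad()
loss.backward()  # fills p.grad for every parameter that dx depends on
optimizer.step()

# The final-layer bias drops out of d(out)/dx, so its .grad stays None.
```

Using `grad_outputs` of ones sums over the batch, which still yields per-sample input gradients only because samples do not interact; layers such as BatchNorm in training mode break that assumption.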

For example, I have:

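```python
import torch
from torch.autograd import grad

class derivativeANN(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # tempstate, model, batch and device are defined elsewhere in my script
        with torch.enable_grad():  # autograd is disabled inside Function.forward by default
            tempstate2 = tempstate.detach().requires_grad_(True)
            ans = model(tempstate2.reshape(-1, 8))[:, 1]
            # d(ans)/d(state); keep the first two state components
            ds = grad(ans, tempstate2, grad_outputs=torch.ones((batch,)).to(device), retain_graph=True)
        return ds[0][:, :2]

    @staticmethod
    def backward(ctx, grad_output):
        # Not working yet: forward never calls ctx.save_for_backward(...)
        input, ds, params12 = ctx.saved_tensors
        dsdpar = []
        for each in input:
            dsdpar.append(grad(ds, each, grad_outputs=torch.ones_like(ds).to(device)))
```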

But now I need each element of dsdpar to update the gradient of the corresponding element of list(model.parameters()). I'm sure there is a way to use torch.autograd.Function with multiple inputs; I just don't know how.
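On the narrower question of torch.autograd.Function with several inputs: `backward` must return one gradient per argument of `forward`, in the same order (or `None` for inputs that need no gradient), so a parameter list can be unpacked into separate inputs with `*params`. A minimal sketch using a hypothetical affine op, checked against finite differences with `torch.autograd.gradcheck`:

```python
import torch

class Affine(torch.autograd.Function):
    """Hypothetical example: y = x @ w.T + b with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, *params):
        w, b = params
        ctx.save_for_backward(x, w)
        return x @ w.t() + b

    @staticmethod
    def backward(ctx, grad_y):
        x, w = ctx.saved_tensors
        grad_x = grad_y @ w         # dL/dx, shape (N, in)
        grad_w = grad_y.t() @ x     # dL/dw, shape (out, in)
        grad_b = grad_y.sum(dim=0)  # dL/db, shape (out,)
        # Return exactly one gradient per forward input, in the same order
        # (x first, then each element of *params).
        return grad_x, grad_w, grad_b


layer = torch.nn.Linear(3, 2).double()
x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)

y = Affine.apply(x, *layer.parameters())  # parameters passed as separate inputs
y.sum().backward()                        # populates layer.weight.grad and layer.bias.grad

# Compare the hand-written backward with finite differences.
ok = torch.autograd.gradcheck(Affine.apply, (x, layer.weight, layer.bias))
```

Each tensor returned by `backward` is accumulated into the `.grad` of the matching input, which is how every parameter receives its own gradient.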

Hi @Cameron_Fen,

If you want to compute the derivative of your network's output with respect to its input and its parameters efficiently, you can use the functorch library (which now ships with PyTorch). A brief example is shown below.

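```python
import torch
from functorch import make_functional, jacrev, vmap

net = Model(*args, **kwargs)  # your own nn.Module

# fnet(params, x) is a pure function; params is a tuple of the network's weights
fnet, params = make_functional(net)

B = 100   # number of samples
size = 4  # 4 input nodes (for example)

x = torch.rand(B, size)

# Inner jacrev: d fnet / d x (argnums=1); outer jacrev: derivative of that w.r.t. params (argnums=0).
# vmap maps over the batch dimension of x (in_dims=0) while sharing params (in_dims=None).
results = vmap(jacrev(jacrev(fnet, argnums=1), argnums=0), in_dims=(None, 0))(params, x)
```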
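In PyTorch 2.x, functorch's transforms live in `torch.func`, and `make_functional` is deprecated in favour of `torch.func.functional_call`. A sketch of the same computation under that API, with a hypothetical 4-16-3 network standing in for `Model(*args, **kwargs)`. The last lines also take `torch.func.grad` of a made-up scalar loss built from dANN/dx, which is usually what an optimiser needs and avoids materialising the full mixed Jacobian:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, jacrev, vmap

torch.manual_seed(0)

# Hypothetical model standing in for Model(*args, **kwargs)
net = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 3))
params = {name: p.detach() for name, p in net.named_parameters()}

def fnet(params, x):
    # Run net with the given parameter dict on a single, unbatched sample x
    return functional_call(net, params, (x,))

B, size = 100, 4
x = torch.rand(B, size)

# Per-sample input Jacobian d fnet / d x: shape (B, 3, 4)
dfdx = vmap(jacrev(fnet, argnums=1), in_dims=(None, 0))(params, x)

# Per-sample mixed derivative d/d params of (d fnet / d x):
# a dict keyed like params, each entry of shape (B, 3, 4, *param.shape)
mixed = vmap(jacrev(jacrev(fnet, argnums=1), argnums=0), in_dims=(None, 0))(params, x)

def loss_fn(params, x):
    # Hypothetical objective on the input Jacobian
    g = vmap(jacrev(fnet, argnums=1), in_dims=(None, 0))(params, x)
    return g.pow(2).mean()

grads = grad(loss_fn)(params, x)  # dict of gradients, one per parameter
```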

If your function has more outputs than inputs, use jacfwd (forward-mode AD) instead of jacrev (reverse-mode AD); it is usually faster in that case.
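A quick way to see the trade-off, using a hypothetical function with 2 inputs and 50 outputs: both transforms give the same Jacobian, but the cost of forward mode grows with the number of inputs and that of reverse mode with the number of outputs, so jacfwd is the better fit for this tall Jacobian.

```python
import torch
from torch.func import jacfwd, jacrev

torch.manual_seed(0)

# Hypothetical map from 2 inputs to 50 outputs (a "tall" Jacobian)
A = torch.randn(50, 2)

def f(x):
    return torch.sin(A @ x)

x = torch.randn(2)
J_fwd = jacfwd(f)(x)  # forward mode: cost scales with the number of inputs (2)
J_rev = jacrev(f)(x)  # reverse mode: cost scales with the number of outputs (50)
```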

A JAX-like PyTorch library was exactly what I needed. Thanks!