# Taking derivatives of a neural network wrt parameters after first being differentiated wrt state

I have a neural network `ANN(x)` that is a function of x but also has parameters to optimize. I need to calculate its derivative with respect to x, and then I would like to run gradient descent to optimize `dANN(x)/dx`. This requires taking the derivative of `dANN(x)/dx` with respect to the parameters of `ANN(x)`.

I can do this with the autograd `jacobian` function, but that's really slow. I would like to do this with the faster `grad` or `backward` functions and `torch.autograd.Function`, but I'm not sure how to use the `backward` function so that each parameter in `dANN(x)/dx` gets the correct gradient, especially since the input to the `forward` function is a list of parameters (i.e. `ANN.parameters()`), each of which should receive its own gradient.

For example I have:

```python
class derivativeANN(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ans = model(tempstate2.reshape(-1, 8))[:, 1]
        # save_for_backward must be called once; repeated calls
        # overwrite the previously saved tensors
        ctx.save_for_backward(input, ds, *model.parameters())
        return ds[0][:, :2]

    @staticmethod
    def backward(ctx, grad_output):
        input, ds, *params12 = ctx.saved_tensors
        dspar = []
        for each in input:
            ...
```

But now I need each element of `dspar` to update the grad of each element of `list(model.parameters())`. I'm sure there is a way to write a `torch.autograd.Function` with multiple inputs, I just don't know how.
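On the multiple-inputs point: `forward` can take each parameter as a separate tensor argument, and `backward` then returns one gradient per input, in the same order as the inputs to `forward`. A minimal sketch with a single made-up scalar parameter `w` (purely illustrative, not the original model):

```python
import torch

class Square(torch.autograd.Function):
    """Computes (w * x) ** 2 with hand-written gradients."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return (w * x) ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        # One gradient is returned per forward input, in the same order:
        # d/dx (w^2 x^2) = 2 w^2 x,   d/dw (w^2 x^2) = 2 w x^2
        grad_x = grad_output * 2 * w * w * x
        grad_w = grad_output * 2 * w * x * x
        return grad_x, grad_w
```

Calling `Square.apply(x, w)` and backpropagating then fills `x.grad` and `w.grad` separately, which is the mechanism for routing a distinct gradient to each parameter passed into `forward`.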

Hi @Cameron_Fen,

If you want to compute the derivative of the output of your network with respect to its input and parameters efficiently, you can use the functorch library (which now ships with PyTorch). A brief example is shown below.

```python
import torch
from functorch import make_functional, jacrev, vmap

net = Model(*args, **kwargs)  # any nn.Module

# Turn the stateful module into a pure function fnet(params, x)
fnet, params = make_functional(net)

B = 100   # number of samples
size = 4  # 4 input nodes (for example)

x = torch.rand(B, size)

# Inner jacrev differentiates wrt x (argnums=1), the outer one wrt the
# parameters (argnums=0); vmap maps over the batch dimension of x only.
results = vmap(jacrev(jacrev(fnet, argnums=1), argnums=0), in_dims=(None, 0))(params, x)
```

If you have more outputs than inputs, you should use `jacfwd` instead of `jacrev`: forward-mode AD is more efficient in that case.

A jax-like pytorch library was exactly what I needed. Thanks!