Taking nth Derivatives of x(t)

Hello,

I have a neural network function x(t) which takes a scalar t and returns a scalar x. I have m = 5000 time points t_1, …, t_m and would like to compute the nth-order derivative of x with respect to t at each of these time points.

Currently, I am using the naive approach of computing the nth derivative w.r.t. t at each sample point t_i separately, as described here: python - Higher order gradients in pytorch - Stack Overflow
but this is prohibitively slow!
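For context, this is roughly what my current loop looks like (the network below is just a stand-in for my model, and n is the derivative order):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1))  # stand-in for x(t)
ts = torch.linspace(0.0, 1.0, 5000)
n = 2  # derivative order

derivs = []
for t_i in ts:
    t = t_i.reshape(1).requires_grad_(True)  # one sample at a time
    x = model(t)
    for _ in range(n):
        # differentiate the previous result w.r.t. t to go up one derivative order
        x = torch.autograd.grad(x, t, create_graph=True)[0]
    derivs.append(x.detach())
derivs = torch.stack(derivs)  # shape (5000, 1)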

I have looked into potential ways of making torch.autograd.grad accept batches of input points (t_1, …, t_m), but without success. It seems this functionality is not implemented, but I was hoping there is a better approach than computing the nth derivative with respect to each sample in a for loop.

Thank you!


We actually just added that to autograd.grad: torch.autograd.grad — PyTorch master documentation

You can use the is_grads_batched flag to specify that the grads you pass in have an extra batch dimension.
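For example, here is a minimal sketch of the flag (not your exact use case): batching one-hot grad_outputs vectors gives you all the rows of the Jacobian in a single call.

import torch

x = torch.randn(3, requires_grad=True)
out = x.sin()

# One grad_outputs vector per row of the identity; the leading dim is the batch dim.
v = torch.eye(3)
jac = torch.autograd.grad(out, x, grad_outputs=v, is_grads_batched=True)[0]
# jac has shape (3, 3): one vector-Jacobian product per row of v.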


I feel like is_grads_batched is not doing exactly what I was expecting. Namely, given a neural network x(t), it computes the derivative at one fixed point, x’(t*), and then, for a batch of inputs t_1, …, t_m, computes the products x’(t*)·t_1, …, x’(t*)·t_m.

What I am trying to do is compute x’(t_1), …, x’(t_m). In other words, I would like to compute the derivative of the neural network x(t) at a batch of sample points t_1, …, t_m. Is there a batched way to do this in PyTorch?

Turns out all I needed was a function like this:

def deriv(self, t, n=1):
    # t must be created with requires_grad=True for autograd to reach it
    dx = self.forward(t)
    for i in range(n):
        # differentiate the previous value of dx w.r.t. t to get the next-order derivative
        dx = torch.autograd.grad(outputs=dx, inputs=t, grad_outputs=torch.ones_like(t), create_graph=True)[0]
    return dx
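For reference, here is how I call it (model is my network with deriv defined as a method; the shapes are just what I use). This works in one pass because each output x_i depends only on t_i, so the vector-Jacobian product with a ones vector picks out exactly dx_i/dt_i for every sample:

t = torch.linspace(0.0, 1.0, 5000).reshape(-1, 1).requires_grad_(True)
d2x = model.deriv(t, n=2)  # second derivative at all 5000 points at once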

Hi,

You can make your function accept a batch of inputs (most provided functions already handle that) and then backward through that.

I’m not sure what your code sample is doing, since i is not used inside the loop.

i is not used inside the loop because the loop just computes the nth derivative by repeatedly differentiating the previous value of dx.

Oh, I didn’t understand your notation, sorry… Yes, that would be the way to get the nth derivative.