I have a neural network function x(t) which accepts a scalar t and returns a scalar x. I have m = 5000 time points t_1, …, t_m and would like to compute the nth order derivatives of x with respect to t at these different time points.
Currently, I am using the naive approach of computing the nth derivative with respect to t at each sample point t_i separately, as described in the Stack Overflow question "Higher order gradients in pytorch", but this is prohibitively slow!
I have looked into potential ways of making torch.autograd.grad accept batches of input points (t_1, …, t_m) but with no success. It seems like this functionality is not implemented but I was hoping there is a better approach than computing the nth derivative with respect to each sample in a for loop.
We actually just added that to autograd.grad (see torch.autograd.grad in the PyTorch master documentation). You can use the is_grads_batched flag to specify that the grads you give have an extra batch dimension.
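To make concrete what the flag batches over, here is a minimal sketch, assuming a PyTorch version that already ships is_grads_batched. The function torch.sin and the 4-point grid are stand-ins chosen for illustration, not anything from the thread. The batch dimension applies to grad_outputs (one vector-Jacobian product per row), not to the points at which the derivative is evaluated.

```python
import torch

# Stand-in for the network: x(t) = sin(t), evaluated on 4 sample points.
t = torch.linspace(0.0, 1.0, 4, requires_grad=True)
x = torch.sin(t)

# Each row of `basis` is one grad_outputs vector; is_grads_batched=True
# treats the first dimension as a batch of such vectors.
basis = torch.eye(4)
(jac,) = torch.autograd.grad(x, t, grad_outputs=basis, is_grads_batched=True)

# jac[i] is the vector-Jacobian product for basis[i], so jac is the full
# 4x4 Jacobian dx_i/dt_j, which is diagonal here with cos(t_i) on the diagonal.
print(jac.shape)
```

So the flag is useful for building a Jacobian in one call, which is a different job from evaluating x'(t_i) at many points t_i.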
I feel like is_grads_batched is not doing exactly what I was expecting. Namely, given a neural network x(t), it is computing the derivative at one fixed point x’(t*) and then for a batch of inputs t_1, …, t_m it is computing the product x’(t*) t_1, …, x’(t*) t_m.
What I am trying to do is compute x’(t_1), …, x’(t_m). In other words, I would hope to compute the derivative of the neural network x(t) at a batch of sample points t_1, …, t_m. Is there a batched way to do this in pytorch?
Turns out all I needed was a function like this:
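def deriv(self, t, n=1):
    # t is the whole batch of time points and must have requires_grad=True
    dx = self.forward(t)
    for _ in range(n):
        # autograd.grad returns a tuple, so take its single element
        dx = torch.autograd.grad(outputs=dx, inputs=t, grad_outputs=torch.ones_like(dx), create_graph=True)[0]
    return dx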
You can have your function take a batch of inputs (most provided functions handle that already) and then backward through that.
Not sure what your code sample is doing since i is not used inside the loop.
i is not used inside the loop because it is just computing the nth derivative by differentiating the previous value of dx.
Oh, I misunderstood your notation, sorry… Yes, that would be the way to get the nth derivative.
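For anyone landing here later, the approach settled on in this thread can be sketched end to end as follows. The small tanh MLP, its hidden size, and the names Net and nth_deriv are assumptions made for illustration, not the original poster's model. The trick relies on each output x(t_i) depending only on its own t_i (no cross-sample layers such as batch norm), so differentiating the sum of the outputs yields every per-sample derivative at once.

```python
import torch
import torch.nn as nn


def nth_deriv(x, t, n=1):
    """Differentiate x with respect to t n times, for all samples at once."""
    dx = x
    for _ in range(n):
        # grad_outputs=ones sums the outputs; since x_i depends only on t_i,
        # the result holds d(x_i)/d(t_i) in position i.
        (dx,) = torch.autograd.grad(
            dx, t, grad_outputs=torch.ones_like(dx), create_graph=True
        )
    return dx


class Net(nn.Module):
    # Hypothetical scalar-in, scalar-out network x(t).
    def __init__(self, hidden=32):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(1, hidden), nn.Tanh(), nn.Linear(hidden, 1)
        )

    def forward(self, t):
        return self.body(t)

    def deriv(self, t, n=1):
        return nth_deriv(self.forward(t), t, n)


torch.manual_seed(0)
net = Net()

# 5000 time points as a (5000, 1) batch that tracks gradients.
t = torch.linspace(0.0, 1.0, 5000).unsqueeze(1).requires_grad_(True)
d2x = net.deriv(t, n=2)  # second derivative at every t_i, shape (5000, 1)

# Sanity check on a function with a known answer: d^2/ds^2 sin(s) = -sin(s).
s = torch.linspace(0.0, 1.0, 5, requires_grad=True)
d2_sin = nth_deriv(torch.sin(s), s, n=2)
print(torch.allclose(d2_sin, -torch.sin(s), atol=1e-5))
```

Keeping create_graph=True on every pass is what allows the next pass to differentiate the previous result. If the derivative is only needed as a number and not for further training, the final result can be detached.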