How to compute element-wise gradients efficiently?

Hello guys!

I've recently run into an efficiency problem when computing element-wise gradients.

To make the problem concrete, assume there is an input tensor A of shape [M]. The computational graph looks like

B=model(A)  # A, B are both in the shape of [M]

Now I'd like the gradients sum(dB[i]/dA[i]) and sum(d^2B[i]/dA[i]^2), which I implemented as

import torch
sum_1st_grad = 0.0
sum_2nd_grad = 0.0
for i in range(A.shape[0]):   # A.shape[0] = M
    # differentiate w.r.t. the full A: indexing A[i] creates a new tensor
    # that is not part of the graph, so grad() would fail on it;
    # create_graph=True is needed so the second derivative can be taken
    grad1 = torch.autograd.grad(B[i], A, create_graph=True)[0][i]
    grad2 = torch.autograd.grad(grad1, A, retain_graph=True)[0][i]
    sum_1st_grad += grad1.detach()
    sum_2nd_grad += grad2.detach()

but it is SUPER slow, as you might expect.

Have you ever had similar problems? Do you have any idea for a more efficient implementation?
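If the model is strictly element-wise, i.e. each B[i] depends only on A[i] (an assumption; your model may differ), then the Jacobian is diagonal and a single backward pass over B.sum() recovers every dB[i]/dA[i] at once. A minimal sketch, with a toy stand-in for model(A):

```python
import torch

# toy stand-in for model(A); assumption: the real model is also
# element-wise, so that B[i] depends only on A[i]
A = torch.randn(8, requires_grad=True)
B = torch.sin(A) * A

# one backward pass gives all diagonal entries dB[i]/dA[i] at once
g1 = torch.autograd.grad(B.sum(), A, create_graph=True)[0]
# a second backward pass gives all d^2B[i]/dA[i]^2
g2 = torch.autograd.grad(g1.sum(), A)[0]

sum_1st_grad = g1.sum()
sum_2nd_grad = g2.sum()
```

This replaces the M-iteration loop with two backward passes, but only under the diagonal-Jacobian assumption.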


Did you find a solution?

You can check out torch.vmap, a feature that came out a few months ago.
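A minimal sketch of that approach, assuming PyTorch >= 2.0 (torch.func) and using a toy element-wise function in place of the real model:

```python
import torch
from torch.func import vmap, grad

# hypothetical stand-in for the model: a scalar -> scalar function
# applied element-wise (replace with your own per-element function)
def f(a):
    return torch.sin(a) * a

A = torch.randn(8)

d1 = vmap(grad(f))(A)        # dB[i]/dA[i] for every i, in one call
d2 = vmap(grad(grad(f)))(A)  # d^2B[i]/dA[i]^2 for every i

sum_1st_grad = d1.sum()
sum_2nd_grad = d2.sum()
```

vmap batches the per-element grad calls instead of looping in Python, which is typically much faster than M separate autograd.grad calls.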

I guess that is similar to functorch.vmap, right? Should even be the same underlying code.
Check this notebook: I'm not able to calculate gradients with it, it's too memory intensive.
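If vmapping over the whole tensor runs out of memory, one workaround is to vmap over fixed-size chunks so only one chunk's graph is alive at a time. A sketch under the same toy-model assumption (f is a hypothetical stand-in for your per-element function):

```python
import torch
from torch.func import vmap, grad

def f(a):
    return torch.sin(a) * a  # hypothetical stand-in for the model

A = torch.randn(1000)

# process A in chunks of 100; smaller chunks use less memory
# at the cost of more vmap calls
sum_1st_grad = torch.tensor(0.0)
sum_2nd_grad = torch.tensor(0.0)
for chunk in torch.split(A, 100):
    sum_1st_grad += vmap(grad(f))(chunk).sum()
    sum_2nd_grad += vmap(grad(grad(f)))(chunk).sum()
```

Recent PyTorch versions also expose a chunk_size argument on torch.vmap itself that does the same thing internally.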