Hello guys!
I recently ran into an efficiency problem when computing element-wise gradients.
To make my problem clear, assume there is an input tensor A of shape [M]. The computational graph looks like
B = model(A)  # A and B both have shape [M]
Now I'd like the gradients sum_i dB[i]/dA[i] and sum_i d^2B[i]/dA[i]^2, and I implemented them as
import torch

sum_1st_grad = 0.0
sum_2nd_grad = 0.0
for i in range(A.shape[0]):  # A.shape[0] == M
    # Differentiate w.r.t. the whole tensor A (indexing, as in A[i], creates a
    # new tensor that is not part of B's graph), then pick out entry i.
    # create_graph=True makes grad1 itself differentiable for the second pass.
    grad1 = torch.autograd.grad(B[i], A, create_graph=True)[0][i]
    # retain_graph=True keeps the forward graph alive for the next iteration.
    grad2 = torch.autograd.grad(grad1, A, retain_graph=True)[0][i]
    sum_1st_grad += grad1
    sum_2nd_grad += grad2
but it is SUPER slow, as you might expect: each torch.autograd.grad call runs a full backward pass, so this is 2*M separate backward passes inside a Python loop.
Have you ever run into a similar problem? Do you have any ideas for a more efficient implementation?
Thanks!
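
Edit: in case it helps, here is a sketch of the kind of vectorized version I am hoping for, using torch.func from PyTorch >= 2.0. I have NOT verified that it is faster for my real network, and the small Sequential model below is just a stand-in, so please treat it as an assumption-laden sketch rather than a working answer:

import torch
from torch.func import jacrev

M = 16
A = torch.randn(M)
model = torch.nn.Sequential(           # stand-in for the real model(A)
    torch.nn.Linear(M, M), torch.nn.Tanh(), torch.nn.Linear(M, M)
)

def diag_jac(a):
    # g[i] = dB[i]/dA[i]: the diagonal of the Jacobian of model at a
    return jacrev(model)(a).diagonal()

g = diag_jac(A)                        # first derivatives, shape [M]
sum_1st_grad = g.sum()

# dg[i]/dA[i] = d^2 B[i]/dA[i]^2, so take the diagonal once more
h = jacrev(diag_jac)(A).diagonal()     # second derivatives, shape [M]
sum_2nd_grad = h.sum()

If I understand the docs correctly, jacrev batches the per-row backward passes with vmap instead of a Python loop, which is where the speedup would come from, but I would appreciate confirmation from someone who knows torch.func better.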