Hello Everyone,
I am trying to make my code more efficient by storing a computation graph, say x, somewhere and then calling autograd.grad on it multiple times, rather than re-computing x each time. Let me give a more concrete example.
import torch
from torch.autograd import grad

def hvp(y, w, v):
    if len(w) != len(v):
        raise ValueError("w and v must have the same length.")
    # First backprop: gradients of y w.r.t. w, with create_graph=True so we
    # can differentiate through them again
    first_grads = grad(y, w, retain_graph=True, create_graph=True)
    # Elementwise products: dot product of first_grads with v
    elemwise_products = 0
    for grad_elem, v_elem in zip(first_grads, v):
        elemwise_products += torch.sum(grad_elem * v_elem)
    # Second backprop
    return_grads = grad(elemwise_products, w, create_graph=False)
    # grad returns a tuple of tensors, so detach each one individually
    return_grads = [g.detach() for g in return_grads]
    return return_grads
Here, first_grads will be the same every time I call the function (since y and w are also the same), so it makes little sense to re-compute y, w, and first_grads on every call (y is the model loss, w contains the model parameters). Is there a way to retain first_grads so I don't have to re-compute it? I have already tried creating a class that stores y, w, and first_grads and exposes this hvp as one of its methods (a rough sketch is below). I also tried storing return_grads.detach() in a different variable, in case detach() was deleting the whole graph.
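For reference, here is roughly what that caching class looks like (just a sketch, not working code I'm confident in; the name CachedHVP and the attribute names are placeholders I made up). The idea is to compute first_grads once in __init__ and then pass retain_graph=True in the second backprop so the graph of first_grads is not freed between calls:

import torch
from torch.autograd import grad

class CachedHVP:
    def __init__(self, y, w):
        self.w = w
        # Compute the first backprop once; create_graph=True builds the graph
        # needed to differentiate through first_grads later
        self.first_grads = grad(y, w, retain_graph=True, create_graph=True)

    def hvp(self, v):
        if len(self.w) != len(v):
            raise ValueError("w and v must have the same length.")
        elemwise_products = 0
        for grad_elem, v_elem in zip(self.first_grads, v):
            elemwise_products += torch.sum(grad_elem * v_elem)
        # retain_graph=True here so the graph through first_grads survives
        # and hvp can be called again with a different v
        second_grads = grad(elemwise_products, self.w, retain_graph=True)
        return [g.detach() for g in second_grads]

Usage would be something like hvp_obj = CachedHVP(loss, list(model.parameters())) followed by hvp_obj.hvp(v) for each v, but I'm not sure whether the graph is actually being kept alive across calls this way.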
Thank you,
Ethan