How do I store a graph to use autograd.grad() on it again later?

Hello Everyone,

I am attempting to make my code more efficient by storing a graph output, say x, somewhere and then calling autograd.grad on x multiple times, rather than re-computing x before each call. Let me give a more concrete example.

import torch
from torch.autograd import grad

def hvp(y, w, v):
    if len(w) != len(v):
        raise ValueError("w and v must have the same length.")

    # First backprop: create_graph=True so the gradients can be
    # differentiated again in the second backprop
    first_grads = grad(y, w, retain_graph=True, create_graph=True)

    # Elementwise products: dot product of the gradient with v
    elemwise_products = 0
    for grad_elem, v_elem in zip(first_grads, v):
        elemwise_products += torch.sum(grad_elem * v_elem)

    # Second backprop: grad() returns a tuple of tensors, so detach
    # each element individually
    return_grads = grad(elemwise_products, w, create_graph=False)

    return [g.detach() for g in return_grads]
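
For context, here is roughly how I call it (a minimal sketch with a made-up model and loss; the setup below is just illustrative):

import torch
import torch.nn.functional as F

# Hypothetical model and data, only to show the call
model = torch.nn.Linear(3, 1)
inputs = torch.randn(8, 3)
targets = torch.randn(8, 1)

w = list(model.parameters())             # model parameters
y = F.mse_loss(model(inputs), targets)   # model loss
v = [torch.randn_like(p) for p in w]     # vector to multiply with the Hessian

hv = hvp(y, w, v)  # one Hessian-vector product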

Here first_grads will be the same each time I call the function (since y and w are also the same), so it makes little sense to re-compute y, w, and first_grads on every call (y is the model loss, and w contains the model parameters). Is there a way to retain first_grads so I don't have to re-compute it? I have already tried creating a class that stores y, w, and first_grads and exposes this hvp as one of its methods (a sketch is below). I also tried storing return_grads.detach() in a separate variable in case detach() freed the whole graph.
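
For reference, here is a minimal sketch of the kind of class I mean (the name CachedHVP is just a placeholder, and I am guessing the second grad call needs retain_graph=True so the stored graph survives between calls):

import torch
from torch.autograd import grad

class CachedHVP:
    # Hypothetical helper: computes first_grads once, reuses it per call
    def __init__(self, y, w):
        self.w = w
        # create_graph=True builds a differentiable graph that stays
        # alive for as long as self.first_grads is referenced
        self.first_grads = grad(y, w, retain_graph=True, create_graph=True)

    def hvp(self, v):
        if len(self.w) != len(v):
            raise ValueError("w and v must have the same length.")
        elemwise_products = 0
        for grad_elem, v_elem in zip(self.first_grads, v):
            elemwise_products += torch.sum(grad_elem * v_elem)
        # retain_graph=True so the first_grads graph is not freed and
        # the next call can backprop through it again
        return_grads = grad(elemwise_products, self.w, retain_graph=True)
        return [g.detach() for g in return_grads]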

Thank you,
Ethan