Hi, how should I save the computation graph of a gradient vector computed from `torch.autograd.grad(loss, model.parameters(), create_graph=True)`?
The background is that I want to compute the Hessian-vector products of k vectors: H V, in which H is the Hessian of a neural network with n parameters, and V is a constant matrix with n rows and k columns. To do that, I compute the gradient of the inner product between the gradient g of the network forward function and V, with respect to the network parameters. An example that works for a tiny network is
```python
import torch

# define the tiny "network"
class quadratic_fun(torch.nn.Module):
    def __init__(self):
        super(quadratic_fun, self).__init__()
        self.x = torch.nn.Parameter(torch.ones(5, requires_grad=True))
        self.y = torch.nn.Parameter(torch.ones(5, requires_grad=True))

    def forward(self):
        loss = torch.norm(self.x) ** 2 + torch.norm(self.y) ** 2
        return loss

# compute the flattened gradient with create_graph=True
model = quadratic_fun()
loss_quad = model.forward()
grad_ft = torch.autograd.grad(loss_quad, model.parameters(), create_graph=True)
flat_grad = torch.cat([g.contiguous().view(-1) for g in grad_ft])

# generate the constant matrix V, and compute the matrix-gradient product
torch.manual_seed(0)
V = torch.randn((10, 3))
h = torch.matmul(flat_grad, V)

# compute the matrix-Jacobian product by iterating over the columns of the constant matrix
for i in range(3):
    hvp = torch.autograd.grad(h[i], model.parameters(), retain_graph=True)
    hvp_flat = torch.cat([g.contiguous().view(-1) for g in hvp])
    print(hvp_flat)
```
which gives

```
tensor([-2.2517, -0.8678, -0.6320, -2.5267,  0.2397, -0.2232, -0.9854,  0.2248,
        -0.2046,  0.1050])
tensor([-2.3047,  1.6974, -4.2304,  0.7000,  2.4753, -1.2272,  0.4968, -1.6821,
         1.5849,  1.0457])
tensor([-0.5012,  1.3840,  0.6445,  0.6163, -0.2869,  0.0632,  0.8794, -4.6321,
        -0.5793,  4.6044])
```
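For this toy problem the result can also be verified analytically: the loss is ||x||^2 + ||y||^2, whose Hessian with respect to the ten stacked parameters is 2I, so each Hessian-vector product should simply be twice the corresponding column of V. A minimal self-contained check (restating the example above):

```python
import torch

class quadratic_fun(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x = torch.nn.Parameter(torch.ones(5))
        self.y = torch.nn.Parameter(torch.ones(5))

    def forward(self):
        return torch.norm(self.x) ** 2 + torch.norm(self.y) ** 2

model = quadratic_fun()
grad_ft = torch.autograd.grad(model(), model.parameters(), create_graph=True)
flat_grad = torch.cat([g.contiguous().view(-1) for g in grad_ft])

torch.manual_seed(0)
V = torch.randn((10, 3))
h = torch.matmul(flat_grad, V)
for i in range(3):
    hvp = torch.autograd.grad(h[i], model.parameters(), retain_graph=True)
    hvp_flat = torch.cat([g.contiguous().view(-1) for g in hvp])
    # the Hessian of ||x||^2 + ||y||^2 is 2*I, so H V[:, i] == 2 V[:, i]
    assert torch.allclose(hvp_flat, 2 * V[:, i])
```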
However, this is not feasible on CUDA when H is the Hessian of a large neural network: with `retain_graph=True` in the `torch.autograd.grad` call inside the for loop, the CUDA memory quickly fills up. But if I don't retain the graph, it is freed after one iteration of the loop, and I would need to compute the gradient again, which is time-consuming. Thus I wonder if I can save not only the gradient value but also its associated computation graph (both generated by `grad_ft = torch.autograd.grad(loss_quad, model.parameters(), create_graph=True)`) to a file or buffer, and reload it in a later iteration of the for loop.
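One partial workaround I have seen (a sketch, not a definitive fix, assuming PyTorch 1.11+ where `torch.autograd.grad` accepts `is_grads_batched=True`): all k products can be computed in a single vmapped backward pass, so the graph of the gradient is traversed once instead of k times and never needs to be retained across loop iterations. For the toy example above this might look like:

```python
import torch

class quadratic_fun(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x = torch.nn.Parameter(torch.ones(5))
        self.y = torch.nn.Parameter(torch.ones(5))

    def forward(self):
        return torch.norm(self.x) ** 2 + torch.norm(self.y) ** 2

model = quadratic_fun()
params = list(model.parameters())
grad_ft = torch.autograd.grad(model(), params, create_graph=True)

torch.manual_seed(0)
V = torch.randn((10, 3))
# split the columns of V to match the per-parameter gradient shapes:
# each of the k=3 batch rows of grad_outputs is one column of V
v_x, v_y = V.t()[:, :5], V.t()[:, 5:]  # shapes (3, 5) and (3, 5)

# a single vmapped backward pass computes all k Hessian-vector products,
# so no retain_graph is needed across iterations
hvps = torch.autograd.grad(grad_ft, params, grad_outputs=(v_x, v_y),
                           is_grads_batched=True)
hvp_flat = torch.cat([h.reshape(3, -1) for h in hvps], dim=1)  # (k, n)
```

Note that `is_grads_batched` relies on vmap, which PyTorch documents as a beta feature, so operator coverage may still be an issue for some networks.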
Some other posts I looked into but didn't find an answer in:
- This post suggests using JIT, but it is not clear to me how to use the API for the graph of a gradient vector.
- A reply in this post suggests computing the matrix-Jacobian product with `torch.autograd.functional.jacobian`, but it looks like that API only works when the function whose Jacobian is computed is explicitly defined.
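For completeness, here is a sketch of how the function could be made explicit for a `Module` (assuming PyTorch 2.0+, where `torch.func.functional_call` treats a module's forward as a function of its parameter tensors). It makes `torch.autograd.functional.hvp` usable, though it does not address my actual concern, since `hvp` still re-runs the forward and backward pass for every vector:

```python
import torch

class quadratic_fun(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x = torch.nn.Parameter(torch.ones(5))
        self.y = torch.nn.Parameter(torch.ones(5))

    def forward(self):
        return torch.norm(self.x) ** 2 + torch.norm(self.y) ** 2

model = quadratic_fun()

# express the loss as an explicit function of the parameter tensors,
# so the functional API has something to differentiate
def loss_fn(x, y):
    return torch.func.functional_call(model, {"x": x, "y": y}, ())

torch.manual_seed(0)
V = torch.randn((10, 3))
inputs = (model.x.detach(), model.y.detach())
for i in range(3):
    # split each column of V to match the per-parameter shapes
    v = (V[:5, i].contiguous(), V[5:, i].contiguous())
    _, hvp = torch.autograd.functional.hvp(loss_fn, inputs, v)
    hvp_flat = torch.cat([h.view(-1) for h in hvp])
```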
Thanks!