I am trying to compute only the diagonal of a Hessian matrix. My code looks like this:
```python
# param is a tensor of size (n, n); dissemble it so that we can take
# a grad w.r.t. each individual entry
param = [[param[i][j] for j in range(n)] for i in range(n)]
for row in param:
    for p in row:
        p.requires_grad_()
new_param = torch.stack([torch.stack(row) for row in param])  # reassemble

loss = model(inputs, param=new_param)
first_derivative = torch.autograd.grad(loss, new_param, create_graph=True)[0]

second_derivative = torch.zeros(n, n)
for i in range(n):
    for j in range(n):
        second_derivative[i][j] = torch.autograd.grad(
            first_derivative[i][j], param[i][j], retain_graph=True
        )[0]
```
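For reference, here is a minimal self-contained version of the same loop, with a toy cubic loss standing in for my actual `model` (the loss and the tiny size are just placeholders so it runs on its own; for `loss = sum(p_ij^3)` the Hessian diagonal should be `6 * p_ij`):

```python
import torch

n = 2
base = torch.randn(n, n)  # toy stand-in for the real (n, n) parameter

# Dissemble into per-entry leaf tensors so each can be a grad target
param = [[base[i][j] for j in range(n)] for i in range(n)]
for row in param:
    for p in row:
        p.requires_grad_()
new_param = torch.stack([torch.stack(row) for row in param])  # reassemble

# Toy separable loss standing in for model(inputs, param=new_param)
loss = (new_param ** 3).sum()

first_derivative = torch.autograd.grad(loss, new_param, create_graph=True)[0]

# One backward pass per entry: this is the n^2 bottleneck
second_derivative = torch.zeros(n, n)
for i in range(n):
    for j in range(n):
        second_derivative[i][j] = torch.autograd.grad(
            first_derivative[i][j], param[i][j], retain_graph=True
        )[0]

print(torch.allclose(second_derivative, 6 * base))  # → True
```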
It works, but it takes a very long time (almost as long as computing the full Hessian). The problem seems to be the nested for loops: the same graph has to be backpropagated through n^2 times.
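One direction I have been considering, sketched here with a toy cubic loss standing in for my real model (so the loss is an assumption, not my actual setup): a single extra backward pass gives a full Hessian-vector product H·v via `grad_outputs`, and with random ±1 vectors, v ⊙ Hv is a Hutchinson-style estimate of the diagonal. I am not sure this is the intended approach, though:

```python
import torch

n = 2
param = torch.randn(n, n, requires_grad=True)

# Toy separable loss: sum(p_ij^3), so the true Hessian diagonal is 6 * p_ij
loss = (param ** 3).sum()
grad = torch.autograd.grad(loss, param, create_graph=True)[0]

# One Hessian-vector product: a single backward through the gradient graph
v = torch.randint(0, 2, (n, n)).float() * 2 - 1  # Rademacher +/-1 entries
hv = torch.autograd.grad(grad, param, grad_outputs=v)[0]

# v * Hv is an unbiased estimate of the Hessian diagonal; averaging over
# many v reduces variance. For a separable loss it is exact in one sample.
diag_estimate = v * hv
print(torch.allclose(diag_estimate, 6 * param.detach()))  # → True
```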
Is there a way to speed this up? Or what changes would I need to make to PyTorch to facilitate it?
Many thanks.