I use the following code to evaluate the Hessian matrix:
# eval Hessian matrix
def eval_hessian(loss_grad, model):
    """Return the dense Hessian of a loss w.r.t. ``model.parameters()``.

    Each element of the flattened first-order gradient is differentiated
    again with respect to every parameter, producing one Hessian row per
    gradient element. This takes one backward pass per element and stores a
    full ``n x n`` matrix.

    Args:
        loss_grad: Iterable of first-order gradient tensors, e.g. the result
            of ``autograd.grad(loss, model.parameters(), create_graph=True)``.
            ``create_graph=True`` is needed there so the gradients can be
            differentiated a second time.
        model: Object whose ``parameters()`` yields the tensors to
            differentiate with respect to.

    Returns:
        A NumPy array of shape ``(n, n)``, where ``n`` is the total number of
        elements in ``loss_grad``. Row ``i`` holds the derivatives of
        gradient element ``i`` w.r.t. every parameter element. An empty
        ``loss_grad`` yields a ``(0, 0)`` array.
    """
    flat_grads = [g.contiguous().view(-1) for g in loss_grad]
    if not flat_grads:
        return torch.zeros(0, 0).numpy()

    # Concatenate once instead of growing the vector with repeated torch.cat.
    g_vector = torch.cat(flat_grads)
    n = g_vector.size(0)
    hessian = torch.zeros(n, n)

    # A gradient that does not depend on any parameter (loss linear in all of
    # them) cannot be differentiated again; its Hessian is identically zero.
    if not g_vector.requires_grad:
        return hessian.numpy()

    params = list(model.parameters())
    for idx in range(n):
        # retain_graph keeps the first-order graph alive for the next row;
        # create_graph is not needed because no third derivative is taken.
        # allow_unused lets parameters that this gradient element does not
        # depend on come back as None instead of raising.
        row_grads = autograd.grad(
            g_vector[idx], params, retain_graph=True, allow_unused=True
        )
        hessian[idx] = torch.cat([
            p.new_zeros(p.numel()) if g is None else g.contiguous().view(-1)
            for g, p in zip(row_grads, params)
        ])
    return hessian.detach().cpu().numpy()
where `loss_grad` can be computed like this: `autograd.grad(loss, net.parameters(), create_graph=True)`
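Note: this is only practical for small networks, because it runs one backward pass per parameter element and stores a dense n × n matrix.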