Hello everyone,
I’m currently computing the condition number of the Hessian of the loss with respect to the weights of a neural network. I have working code for this, but it is quite expensive, since it materializes the full Hessian one row at a time. I’m reaching out to the community to ask whether there are more efficient methods or cheaper approximations for this problem.
Here is a snippet of my code for reference:
import torch
from torch.nn.utils import _stateless  # torch.func.functional_call in newer PyTorch

# -- predict
y, logits = _stateless.functional_call(model, params, x.squeeze(1))

# -- loss
loss = loss_func(logits, label.ravel())

for param in [v for k, v in params.items() if 'fc' in k]:
    # -- compute gradient (create_graph=True so we can differentiate again)
    grad = torch.autograd.grad(loss, param, retain_graph=True,
                               create_graph=True, allow_unused=True)[0].reshape(-1)

    # -- compute Hessian, one row per gradient entry
    Hessian = []
    n_params = grad.shape[0]
    for j in range(n_params):
        hess = torch.autograd.grad(grad[j], param, retain_graph=True,
                                   allow_unused=True)[0].reshape(-1)
        Hessian.append(hess)
    Hessian = torch.stack(Hessian)

    # -- compute condition number
    print(torch.linalg.cond(Hessian))
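For reference, the same row-by-row computation above can be expressed in one call with `torch.autograd.functional.hessian`, which I believe is equivalent for a scalar loss. Here is a minimal self-contained sketch on a toy linear model (the weight, data, and loss below are illustrative stand-ins, not my actual model):

```python
import torch

torch.manual_seed(0)
W = torch.randn(4, 3)       # stand-in for one 'fc' weight
X = torch.randn(8, 4)       # batch of inputs
T = torch.randn(8, 3)       # targets

def loss_fn(weight):
    # scalar loss as a function of the weight alone
    pred = X @ weight
    return ((pred - T) ** 2).sum()

# Full Hessian of the loss w.r.t. W, computed in one call;
# result has shape W.shape + W.shape, so flatten it to 2-D.
H = torch.autograd.functional.hessian(loss_fn, W)
H = H.reshape(W.numel(), W.numel())

print(torch.linalg.cond(H))
```

This is still O(n²) in memory and O(n) backward passes internally, so it does not remove the fundamental cost, but it avoids the explicit Python loop.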
I would greatly appreciate any insights, suggestions, or alternative approaches that could help me speed up the Hessian computation or replace it with a cheaper approximation. Thank you in advance for your help!
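In case it helps frame answers: I understand that Hessian-vector products can be formed via double backward without ever building the full matrix, which seems like a natural building block for estimating extreme eigenvalues. A rough sketch on a toy quadratic problem (all names and the problem itself are illustrative; power iteration here only recovers the largest eigenvalue, and it assumes that eigenvalue is well separated):

```python
import torch

torch.manual_seed(0)
# Toy quadratic loss: loss(w) = 0.5 * w^T A w, so the Hessian is exactly A
A = torch.randn(10, 10)
A = A @ A.T + 0.1 * torch.eye(10)   # symmetric positive definite
w = torch.randn(10, requires_grad=True)
loss = 0.5 * w @ (A @ w)

# First backward with create_graph=True so we can differentiate again
grad = torch.autograd.grad(loss, w, create_graph=True)[0]

def hvp(v):
    # Hessian-vector product: d/dw (grad . v) = H v, no full Hessian built
    return torch.autograd.grad(grad @ v, w, retain_graph=True)[0]

# Power iteration for the largest eigenvalue of H
v = torch.randn(10)
v /= v.norm()
for _ in range(500):
    v = hvp(v).detach()
    v /= v.norm()
lam_max = (v @ hvp(v)).item()       # Rayleigh quotient estimate

# Sanity check against a dense eigendecomposition (only possible on a toy)
print(lam_max, torch.linalg.eigvalsh(A).max().item())
```

My understanding is that estimating the smallest eigenvalue (which the condition number also needs) is harder; if anyone has experience with, e.g., shifted power iteration or Lanczos-style methods for that part, I would love to hear about it.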