I’m currently working on computing the condition number of the Hessian matrix for the weights of a neural network. While I have implemented the code for this computation, I have noticed that it is prohibitively expensive: building the full Hessian requires one backward pass per parameter, so the cost scales quadratically with the number of weights. Therefore, I’m reaching out to the community to ask whether there are more efficient methods or approximations for this problem.
Here is a snippet of my code for reference:
```python
import torch
from torch.nn.utils import _stateless

# -- predict
y, logits = _stateless.functional_call(model, params, x.squeeze(1))

# -- loss
loss = loss_func(logits, label.ravel())

for param in [v for k, v in params.items() if 'fc' in k]:
    # -- compute gradient (autograd.grad returns a tuple, hence the [0])
    grad = torch.autograd.grad(loss, param, retain_graph=True,
                               create_graph=True, allow_unused=True)[0].reshape(-1)

    # -- compute Hessian row by row
    Hessian = []
    n_params = grad.shape[0]
    for j in range(n_params):
        hess = torch.autograd.grad(grad[j], param, retain_graph=True,
                                   allow_unused=True)[0].reshape(-1)
        Hessian.append(hess)
    Hessian = torch.stack(Hessian)

    # -- compute condition number
    print(torch.linalg.cond(Hessian))
```
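For context, one direction I have started looking at is Hessian-vector products via double backprop, which avoid ever materializing the full Hessian. The sketch below is a standalone toy, not my training code: the names `hvp` and `top_eig_power_iteration`, and the quadratic loss with matrix `A` (whose Hessian is exactly `A`, so the estimate can be checked), are all illustrative. It estimates the largest Hessian eigenvalue with power iteration:

```python
import torch

def hvp(loss_fn, params, v):
    """Hessian-vector product H @ v via double backprop (no full Hessian)."""
    loss = loss_fn(params)
    g = torch.autograd.grad(loss, params, create_graph=True)[0]
    # grad of (g . v) w.r.t. params gives H @ v
    return torch.autograd.grad(g, params, grad_outputs=v)[0]

def top_eig_power_iteration(loss_fn, params, iters=200):
    """Estimate the largest-magnitude Hessian eigenvalue with power iteration."""
    v = torch.randn_like(params)
    v /= v.norm()
    lam = 0.0
    for _ in range(iters):
        hv = hvp(loss_fn, params, v)
        lam = torch.dot(v, hv).item()  # Rayleigh quotient
        v = hv / (hv.norm() + 1e-12)
    return lam

# Sanity check on a toy quadratic: loss = 0.5 * w^T A w has Hessian A.
torch.manual_seed(0)
A = torch.randn(5, 5)
A = A @ A.T + 5 * torch.eye(5)  # symmetric positive definite
w = torch.randn(5, requires_grad=True)
loss_fn = lambda p: 0.5 * p @ A @ p

est = top_eig_power_iteration(loss_fn, w)
true = torch.linalg.eigvalsh(A).max().item()
print(est, true)
```

For the condition number I would still need the smallest eigenvalue, which power iteration alone does not give; my understanding is that people typically use Lanczos for both extreme eigenvalues (e.g. `scipy.sparse.linalg.eigsh` wrapping the HVP in a `LinearOperator`), but I have not verified this end to end, which is partly why I am asking here.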
I would greatly appreciate any insights, suggestions, or alternative approaches that can help me optimize the computation of the Hessian matrix or find a more affordable approximation. Thank you in advance for your help!