Efficient Methods for Computing or Approximating the Condition Number of the Hessian Matrix in Neural Networks

Hello everyone,

I’m currently working on computing the condition number of the Hessian matrix with respect to the weights of a neural network. I have implemented the computation, but it is quite expensive and time-consuming, so I’m reaching out to the community to ask whether there are more efficient methods or approximations for this problem.

Here is a snippet of my code for reference:

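import torch
from torch.func import functional_call  # public replacement for the private torch.nn.utils._stateless

# -- predict (model, params, x, label and loss_func are defined elsewhere)
y, logits = functional_call(model, params, x.squeeze(1))

# -- loss
loss = loss_func(logits, label.ravel())

for name, param in params.items():
    if 'fc' not in name:
        continue

    # -- compute gradient (create_graph=True so it can be differentiated a second time)
    grad = torch.autograd.grad(loss, param, retain_graph=True, create_graph=True)[0].reshape(-1)

    # -- compute Hessian one row at a time: row j = d(grad[j]) / d(param)
    rows = []
    n_params = grad.shape[0]
    for j in range(n_params):
        hess = torch.autograd.grad(grad[j], param, retain_graph=True, allow_unused=True)[0]
        if hess is None:  # grad[j] does not depend on param
            hess = torch.zeros_like(param)
        rows.append(hess.reshape(-1))
    Hessian = torch.stack(rows)  # shape (n_params, n_params)

    # -- compute condition number (2-norm: largest / smallest singular value)
    print(name, torch.linalg.cond(Hessian))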

I would greatly appreciate any insights, suggestions, or alternative approaches that can help me optimize the computation of the Hessian matrix or find a more affordable approximation. Thank you in advance for your help!

Hi @blade,

Depending on where the bottleneck in your code is, you could compute your derivatives with the torch.func package, which lets you vectorize the for j in range(n_params): loop and should speed up your calculation. In fact, torch.func provides torch.func.hessian for computing the Hessian directly, and if you need it for multiple samples you can vectorize over the samples with torch.func.vmap. Both transforms compose, e.g. vmap(hessian(my_func)).
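A minimal sketch of the torch.func route, assuming PyTorch 2.x. The toy Net, the random data, and the helper names loss_wrt and hessian_and_cond are hypothetical stand-ins for the model, params, x and label in the question (the toy model returns only logits). Since the Hessian is symmetric, it uses torch.linalg.eigvalsh instead of the SVD behind torch.linalg.cond.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, hessian

torch.manual_seed(0)

class Net(nn.Module):
    """Hypothetical toy model; unlike the question's model it returns only the logits."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 3)

    def forward(self, x):
        return self.fc2(torch.tanh(self.fc1(x)))

model = Net().double()
x = torch.randn(16, 4, dtype=torch.float64)
label = torch.randint(0, 3, (16,))
loss_func = nn.CrossEntropyLoss()
params = {k: v.detach() for k, v in model.named_parameters()}

def loss_wrt(name, p):
    """Loss as a function of one parameter tensor, with all other parameters held fixed."""
    logits = functional_call(model, {**params, name: p}, (x,))
    return loss_func(logits, label)

def hessian_and_cond(name):
    p = params[name]
    # torch.func.hessian is jacfwd(jacrev(...)), which vectorizes the per-row loop internally.
    # The result has shape p.shape + p.shape, so flatten it to (n, n).
    H = hessian(lambda q: loss_wrt(name, q))(p).reshape(p.numel(), p.numel())
    # H is symmetric, so eigvalsh can replace an SVD; in the 2-norm,
    # cond(H) = max|lambda| / min|lambda| (inf if H is exactly singular).
    abs_eigs = torch.linalg.eigvalsh(H).abs()
    return H, abs_eigs.max() / abs_eigs.min()

for name in params:
    if 'fc' in name:
        H, cond = hessian_and_cond(name)
        print(name, tuple(H.shape), cond.item())
```

For the output layer of a softmax cross-entropy model, expect a huge or infinite value: adding the same constant to every logit leaves the loss unchanged, so the Hessian with respect to the last layer's weights and bias is singular.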

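In the 2-norm, the condition number is the ratio of the largest to the smallest singular value, which for a symmetric matrix such as the Hessian is the ratio of the largest to the smallest eigenvalue in absolute value (other norms give different definitions). You could compute the largest eigenvalue with the power-iteration method, which only needs Hessian-vector products, but I’m not sure how to apply such a method to the smallest eigenvalue, as the standard approach (inverse iteration) requires the inverse of your Hessian, and inverting a matrix scales cubically with its size.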
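A minimal sketch of power iteration driven by Hessian-vector products, assuming PyTorch 2.x with torch.func; make_hvp, power_iteration and cond_psd are hypothetical helper names. Each product is computed forward-over-reverse (a jvp of grad), so the n x n Hessian is never built. For the smallest eigenvalue it swaps in a shifted power iteration on lam_max * I - H instead of inverting H. That shift is only valid if the Hessian is positive semi-definite, which holds at a local minimum but not in general for a neural-network loss; for an indefinite Hessian it returns the most negative eigenvalue, not the one closest to zero.

```python
import torch
from torch.func import grad, jvp

def make_hvp(f, p):
    """Return v -> H @ v, where H is the Hessian of the scalar function f at p.

    Uses forward-over-reverse differentiation, so H is never materialized.
    """
    grad_f = grad(f)
    return lambda v: jvp(grad_f, (p,), (v,))[1]

def power_iteration(matvec, p, n_iter=300, seed=0):
    """Estimate the dominant (largest-magnitude) eigenvalue of a symmetric operator."""
    gen = torch.Generator().manual_seed(seed)
    v = torch.randn(p.shape, generator=gen, dtype=p.dtype)
    v = v / v.norm()
    eig = torch.zeros((), dtype=p.dtype)
    for _ in range(n_iter):
        w = matvec(v)
        eig = torch.dot(v.flatten(), w.flatten())  # Rayleigh quotient v^T H v
        v = w / w.norm()
    return eig

def cond_psd(f, p, n_iter=300):
    """Condition-number estimate, ASSUMING the Hessian of f at p is positive semi-definite."""
    hvp = make_hvp(f, p)
    lam_max = power_iteration(hvp, p, n_iter)
    # The shifted operator lam_max*I - H has eigenvalues lam_max - lam_i >= 0,
    # so its dominant eigenvalue is lam_max - lam_min.
    mu = power_iteration(lambda v: lam_max * v - hvp(v), p, n_iter)
    lam_min = lam_max - mu
    return lam_max, lam_min, lam_max / lam_min

# Toy check on a quadratic with a known Hessian A (eigenvalues 10, 4, 2, 1)
A = torch.diag(torch.tensor([10.0, 4.0, 2.0, 1.0], dtype=torch.float64))
f = lambda q: 0.5 * q @ A @ q
p0 = torch.randn(4, dtype=torch.float64)
lam_max, lam_min, cond = cond_psd(f, p0)
print(lam_max.item(), lam_min.item(), cond.item())
```

For a network, f would be the loss as a function of one parameter tensor (for example built with torch.func.functional_call) and p the current value of that tensor. Each iteration costs roughly one extra forward and backward pass, and convergence slows when the leading eigenvalues are close together, so n_iter may need tuning.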