Hessian vector product

Hey all,

I need to determine the largest eigenvalue of the Hessian of my loss function. As my model is pretty large, I cannot compute the Hessian directly; it's too expensive. In my code I try to estimate the largest eigenvalue using the power iteration method, with torch.autograd.grad to compute the Hessian-vector product.

However, my function is currently incorrect: it sometimes converges to negative values (the Hessian is positive definite, so the largest eigenvalue must be positive). Is the calculation of the Hessian-vector product correct? I am especially unsure about the part where I flatten the output of torch.autograd.grad to obtain the vector.

Maybe this has numerical issues, so if somebody could confirm the correctness of the Hessian-vector product, that would be great.

Many thanks in advance!!

All the best,
Lukas

def estimate_hessian_eigenvalue(self, loss, params, device, tol=1e-4, max_iter=100, mode="largest"):
    """Estimates the eigenvalue of the Hessian that is largest in absolute
    value, using power iteration with Hessian-vector products.
    Note: the `mode` argument is currently unused."""
    # Materialize params first, in case it is a generator such as
    # model.parameters() -- otherwise the sum below would exhaust it.
    params = list(params)
    num_param = sum(p.numel() for p in params)
    # Gradient of the loss w.r.t. the parameters; create_graph=True so
    # that we can differentiate through it a second time.
    grad_params = torch.autograd.grad(loss, params, create_graph=True)
    grad_params = torch.cat([g.flatten() for g in grad_params])
    # Random (unit-norm) start vector for the power iteration.
    v = torch.rand(num_param, device=device)
    v = v / torch.norm(v)
    # Hessian-vector product: differentiate the flattened gradient,
    # with v as the vector in the vector-Jacobian product.
    Hv = torch.autograd.grad(grad_params, params, v, retain_graph=True)
    Hv = torch.cat([h.flatten() for h in Hv])
    Hv = Hv / torch.norm(Hv)
    eigenvalue = None
    for i in range(max_iter):
        # Apply the Hessian to the current iterate.
        w = torch.autograd.grad(grad_params, params, Hv, retain_graph=True)
        w = torch.cat([e.flatten() for e in w])
        # Rayleigh quotient; Hv has unit norm, so the denominator is 1.
        eigenvalue = torch.dot(Hv, w) / torch.dot(Hv, Hv)
        # Stop once consecutive estimates agree to within tol.
        if i > 0 and torch.abs(eigenvalue - last_eigenvalue) < tol:
            print("tolerance reached")
            break
        last_eigenvalue = eigenvalue
        # Normalize to obtain the next power-iteration vector.
        Hv = w / torch.norm(w)
    return eigenvalue

Hi Lukas!

It is not true, in general, that the Hessian is positive definite – it can certainly
have negative eigenvalues.

(The Hessian is symmetric, so your computed Hessian should be symmetric up
to numerical round-off error. At a true (local or global) minimum, the Hessian
will be positive semi-definite – it can have zero eigenvalues in special cases.
However, even if you have trained well and your model is giving good results,
you’re still not necessarily at a minimum of your loss and the Hessian could
well have negative eigenvalues.)
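
Here is a minimal sketch of this point with a single made-up scalar parameter:
the toy loss f(x) = x**4 - x**2 has perfectly good minima at x = +-1/sqrt(2),
but away from them, at x = 0.1, its (one-by-one) Hessian is negative:

import torch

# toy loss f(x) = x**4 - x**2 with a single scalar parameter
x = torch.tensor(0.1, requires_grad=True)
loss = x**4 - x**2

# second derivative via double backward (same pattern as your code)
grad, = torch.autograd.grad(loss, x, create_graph=True)
hess, = torch.autograd.grad(grad, x)

print(hess)   # tensor(-1.8800), i.e. 12 * 0.1**2 - 2 < 0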

Just to be clear, the “power iteration method” will converge to the eigenvalue
that is largest in absolute value (rather than the algebraically largest). So if
your method returns a negative eigenvalue of -1.002, you could still have a
positive eigenvalue of comparable, but slightly smaller magnitude, say, 1.001.
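
If what you actually want is the algebraically largest eigenvalue, a standard
workaround is to run power iteration on the shifted operator H + shift * I:
its eigenvalues are lambda_i + shift, so any shift >= |lambda_min| makes them
all non-negative, and subtracting the shift from the result recovers
lambda_max. A sketch (the function name and the hvp callable, which is
assumed to map a flat vector v to H @ v, are just for illustration):

import torch

def shifted_power_iteration(hvp, dim, shift, device, iters=100):
    # power iteration on the shifted operator H + shift * I
    v = torch.rand(dim, device=device)
    v = v / torch.norm(v)
    eig = torch.tensor(0.0, device=device)
    for _ in range(iters):
        w = hvp(v) + shift * v       # (H + shift * I) @ v
        eig = torch.dot(v, w)        # Rayleigh quotient (v has unit norm)
        v = w / torch.norm(w)
    return eig - shift               # undo the shift

In practice you can take shift = |result of your current routine|, since that
is the eigenvalue largest in magnitude and hence an upper bound on |lambda_min|.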

I don’t see anything wrong offhand with your code, and, in particular, your
flattening of the output of grad() seems sound. Also, the power iteration
method is pretty robust, so unless your problem is ridiculously large and / or
has a lot of near degeneracy, it should be giving you the right answer.
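
If you would like an independent check of the Hessian-vector product itself
(rather than the eigenvalue), you can compare your double-backward version
against an explicit Hessian on a tiny made-up function, for example via
torch.autograd.functional.hessian. A sketch:

import torch

def f(w):
    # small made-up "loss" in a 3-vector of parameters
    return (w ** 3).sum() + w[0] * w[1]

w = torch.randn(3, requires_grad=True)
v = torch.randn(3)

# Hessian-vector product via double backward (your pattern)
g, = torch.autograd.grad(f(w), w, create_graph=True)
hv, = torch.autograd.grad(g, w, grad_outputs=v)

# reference: full Hessian applied to v
H = torch.autograd.functional.hessian(f, w.detach())
print(torch.allclose(hv, H @ v, atol=1e-6))   # should print True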

Again, in general the Hessian can have negative eigenvalues. If – taking this
into account – you still think you’re having problems, apply your method to an
intermediate-sized problem for which you can compute the full Hessian and
its spectrum, and compare the result of your method against that of
something like torch.linalg.eigh().
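
That comparison might look something like the following sketch, using a
quadratic loss so that the Hessian is known exactly (estimate_hessian_eigenvalue
stands in for your routine; eigvalsh is the eigenvalues-only variant of eigh):

import torch

torch.manual_seed(0)
n = 20

# random symmetric matrix -- the Hessian of the quadratic loss below
A = torch.randn(n, n)
A = 0.5 * (A + A.T)

w = torch.randn(n, requires_grad=True)
loss = 0.5 * w @ A @ w   # the Hessian of this loss is exactly A

evals = torch.linalg.eigvalsh(A)   # full spectrum, ascending order
print("largest in magnitude:", evals[evals.abs().argmax()].item())
print("algebraically largest:", evals[-1].item())
# est = estimate_hessian_eigenvalue(loss, [w], device="cpu")
# est should match the "largest in magnitude" value above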

Best.

K. Frank