Hey all,
I need to determine the largest eigenvalue of the Hessian of my loss function. Because my model is fairly large, computing the Hessian directly is too expensive. Instead, my code estimates the largest eigenvalue with the power iteration method, using torch.autograd.grad to compute Hessian-vector products.
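For reference, the double-backprop Hessian-vector product can be checked in isolation. Below is a minimal sketch that assumes a toy quadratic f(x) = 0.5 * x^T A x with a symmetric A, so the Hessian is exactly A. The names A, x and v are illustrative only:

```python
import torch

torch.manual_seed(0)
B = torch.randn(4, 4, dtype=torch.float64)
A = B + B.T  # symmetric, so the Hessian of f below is exactly A
x = torch.randn(4, dtype=torch.float64, requires_grad=True)
v = torch.randn(4, dtype=torch.float64)

f = 0.5 * x @ A @ x
# First backward pass: keep the graph so the gradient can itself be differentiated
(g,) = torch.autograd.grad(f, x, create_graph=True)
# Second backward pass: d(g . v)/dx = H^T v = H v, since H is symmetric
(Hv,) = torch.autograd.grad(g, x, grad_outputs=v)

print(torch.allclose(Hv, A @ v))  # True
```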
However, my function seems to be incorrect: it sometimes converges to negative values. The Hessian is positive definite, so its largest eigenvalue must be positive. Is my calculation of the Hessian-vector product correct? I am especially unsure about the part where I flatten the output of torch.autograd.grad to obtain the vector.
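One thing worth ruling out before suspecting the HVP: plain power iteration converges to the eigenvalue with the largest absolute value, not the largest signed value. If the Hessian turns out to be indefinite (neural-network loss Hessians often are), a negative estimate is expected. The following sketch uses a hypothetical helper `power_iteration` on a small dense matrix rather than a real Hessian. It also shows the usual shift trick: when the dominant estimate lam_dom is negative, running power iteration on H - lam_dom * I and adding lam_dom back recovers the largest signed eigenvalue. For a Hessian, the shift can be applied matrix-free by using Hv - lam_dom * v in place of Hv.

```python
import torch

def power_iteration(H, num_iters=500, seed=0):
    """Hypothetical helper: Rayleigh-quotient estimate of the dominant (largest |lambda|) eigenvalue of symmetric H."""
    gen = torch.Generator().manual_seed(seed)
    v = torch.rand(H.shape[0], generator=gen, dtype=H.dtype)
    v = v / v.norm()
    for _ in range(num_iters):
        w = H @ v
        v = w / w.norm()
    return torch.dot(v, H @ v).item()

# Indefinite stand-in "Hessian" with eigenvalues 2, 0.5 and -5
H = torch.diag(torch.tensor([2.0, 0.5, -5.0], dtype=torch.float64))
lam_dom = power_iteration(H)
print(lam_dom)  # close to -5.0: the dominant eigenvalue is negative

# Shift: H - lam_dom * I has eigenvalues 7, 5.5 and 0, so its dominant eigenvalue
# is lambda_max - lam_dom; adding lam_dom back gives the largest signed eigenvalue of H
I = torch.eye(3, dtype=torch.float64)
lam_max = power_iteration(H - lam_dom * I) + lam_dom
print(lam_max)  # close to 2.0
```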
Maybe this is a numerical issue, so it would be great if somebody could confirm that the Hessian-vector product is correct.
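One way to check the flattening is to compare the flattened HVP against an explicit Hessian on a model small enough to build one. The sketch below assumes a toy linear layer with parameters W and b, random inputs, and a made-up tanh/MSE loss. torch.autograd.functional.hessian builds the reference Hessian of the same loss as a function of one flat vector. That vector is packed in the same order as the torch.cat call, which is what makes H @ v comparable to the flattened result:

```python
import torch

torch.manual_seed(0)
dtype = torch.float64
x = torch.randn(5, 3, dtype=dtype)  # hypothetical toy inputs
y = torch.randn(5, 2, dtype=dtype)  # hypothetical toy targets
W = torch.randn(2, 3, dtype=dtype, requires_grad=True)
b = torch.randn(2, dtype=dtype, requires_grad=True)
params = [W, b]

def loss_fn(W, b):
    # Small non-quadratic loss, so the Hessian depends on the parameters
    return ((torch.tanh(x @ W.T + b) - y) ** 2).mean()

# Same approach as in the question: flatten the gradient, pass a flat v as
# grad_outputs, then flatten the per-parameter results
grads = torch.autograd.grad(loss_fn(W, b), params, create_graph=True)
flat_grad = torch.cat([g.flatten() for g in grads])
v = torch.randn(flat_grad.numel(), dtype=dtype)
Hv = torch.autograd.grad(flat_grad, params, grad_outputs=v)
Hv = torch.cat([h.flatten() for h in Hv])

# Reference: explicit Hessian w.r.t. one flat vector, packed W first, then b
def loss_flat(theta):
    W_ = theta[:W.numel()].view_as(W)
    b_ = theta[W.numel():].view_as(b)
    return loss_fn(W_, b_)

theta = torch.cat([W.detach().flatten(), b.detach().flatten()])
H = torch.autograd.functional.hessian(loss_flat, theta)  # shape (8, 8)

print(torch.allclose(Hv, H @ v))  # True
```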
Many thanks in advance!!
All the best,
Lukas
import torch

def estimate_hessian_eigenvalue(self, loss, params, device, tol=1e-4, max_iter=100, mode="largest"):
    """Estimates the dominant (largest-magnitude) eigenvalue of the Hessian of `loss` via power iteration."""
    # Materialize params once: a generator such as model.parameters() would be
    # exhausted by the sum below and leave list(params) empty afterwards
    params = list(params)
    # get number of params
    num_param = sum(p.numel() for p in params)
    # Calculate the gradient of the loss with respect to the model parameters;
    # create_graph=True keeps the graph so the gradient can be differentiated again
    grad_params = torch.autograd.grad(loss, params, create_graph=True)
    grad_params = torch.cat([e.flatten() for e in grad_params])  # flatten
    # Start the power iteration from a random unit vector
    v = torch.rand(num_param, device=device, dtype=grad_params.dtype)
    v = v / torch.norm(v)
    # Hessian-vector product: differentiating (grad . v) w.r.t. params gives H v
    Hv = torch.autograd.grad(grad_params, params, v, retain_graph=True)
    Hv = torch.cat([e.flatten() for e in Hv])  # flatten
    # normalize Hv
    Hv = Hv / torch.norm(Hv)
    for i in range(max_iter):
        # Compute the Hessian-vector product H @ Hv
        w = torch.autograd.grad(grad_params, params, Hv, retain_graph=True)
        w = torch.cat([e.flatten() for e in w])  # flatten
        # Rayleigh quotient: estimate of the dominant eigenvalue of the Hessian
        eigenvalue = torch.dot(Hv, w) / torch.dot(Hv, Hv)
        # Check if the difference between consecutive estimates is below the tolerance level
        if i > 0 and torch.abs(eigenvalue - last_eigenvalue) < tol:
            print("tolerance reached")
            break
        last_eigenvalue = eigenvalue
        # Update Hv for the next iteration
        Hv = w / torch.norm(w)
    return eigenvalue
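A standalone sketch of the same power iteration, for testing outside the class. The function name hessian_power_iteration and the quadratic test problem are hypothetical. The loss has a known Hessian A spread over two parameter tensors, and an eigenvalue of -6 dominates, so the estimate comes out negative even though every HVP is exact:

```python
import torch

def hessian_power_iteration(loss, params, num_iters=200, tol=1e-10, seed=0):
    """Hypothetical helper: estimate the dominant (largest |lambda|) Hessian eigenvalue."""
    params = list(params)  # a generator (e.g. model.parameters()) can only be consumed once
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.flatten() for g in grads])
    gen = torch.Generator().manual_seed(seed)
    v = torch.randn(flat_grad.numel(), generator=gen, dtype=flat_grad.dtype).to(flat_grad.device)
    v = v / v.norm()
    eigenvalue = None
    for _ in range(num_iters):
        # Matrix-free Hessian-vector product via double backprop
        Hv = torch.autograd.grad(flat_grad, params, grad_outputs=v, retain_graph=True)
        Hv = torch.cat([h.flatten() for h in Hv])
        new_eigenvalue = torch.dot(v, Hv).item()  # Rayleigh quotient (v has unit norm)
        v = Hv / Hv.norm()
        converged = eigenvalue is not None and abs(new_eigenvalue - eigenvalue) < tol
        eigenvalue = new_eigenvalue
        if converged:
            break
    return eigenvalue


# Toy problem: loss = 0.5 * z^T A z with z = cat(p1, p2), so the Hessian is A
torch.manual_seed(0)
dtype = torch.float64
Q, _ = torch.linalg.qr(torch.randn(5, 5, dtype=dtype))
A = Q @ torch.diag(torch.tensor([3.0, 1.0, 0.5, -0.2, -6.0], dtype=dtype)) @ Q.T
p1 = torch.randn(2, dtype=dtype, requires_grad=True)
p2 = torch.randn(3, dtype=dtype, requires_grad=True)
z = torch.cat([p1, p2])
loss = 0.5 * z @ A @ z

est = hessian_power_iteration(loss, iter([p1, p2]))  # a one-shot iterator works too
ref = torch.linalg.eigvalsh(A)
dominant = ref[ref.abs().argmax()].item()
print(est, dominant)  # both close to -6.0
```

Because power iteration tracks |lambda|, the estimate matches the largest-magnitude eigenvalue reported by torch.linalg.eigvalsh, not the largest signed one.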