Efficient computation of Hessian with respect to network weights using autograd.grad and symmetry of Hessian matrix

As the title says, I am trying to compute the Hessian with respect to the weights (not the inputs) of the network. Right now, my code is as follows:

import torch
from torch import autograd

def hessian(loss, params):
    # first-order gradients, with create_graph=True so they can be differentiated again
    grads = autograd.grad(loss, params, retain_graph=True, create_graph=True)
    flattened_grads = torch.cat([grad.flatten() for grad in grads])

    hessian = torch.zeros(flattened_grads.shape[0], flattened_grads.shape[0])

    # one backward pass per scalar gradient entry fills one row of the Hessian
    for idx, grad in enumerate(flattened_grads):
        second_der = autograd.grad(grad, params, retain_graph=True, allow_unused=True)
        second_der = torch.cat([g.flatten() for g in second_der])
        hessian[idx, :] = second_der

    return hessian

That works fine, but it is inefficient: the Hessian matrix is symmetric, so the whole matrix does not need to be calculated.
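
By "copied" I mean that once the upper triangle is filled, the lower half can be mirrored afterwards, along the lines of this one-liner sketch:

# sketch: mirror the upper triangle into the lower half
hessian = torch.triu(hessian) + torch.triu(hessian, diagonal=1).T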

To make it more efficient I thought of the following:

flat_params = torch.cat([par.flatten() for par in params])
for idx, grad in enumerate(flattened_grads):
    # differentiate only with respect to the entries from idx onward
    second_der = autograd.grad(grad, flat_params[idx:], retain_graph=True, allow_unused=True)
    second_der = torch.cat([g.flatten() for g in second_der])
    hessian[idx, idx:] = second_der  # fill the upper triangle only

Namely, I calculate only the upper triangular part, and the lower half can be copied later. Unfortunately, second_der = autograd.grad(grad, flat_params[idx:], retain_graph=True, allow_unused=True) returns None. How can I fix this? Is there a way to do what I want?
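
I suspect the problem is that flat_params is a new tensor built by torch.cat, so it never appears in the graph that produced loss. A minimal sketch reproducing the same failure:

import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()
flat = torch.cat([x.flatten()])  # a new tensor; it is not an ancestor of y in the graph
print(torch.autograd.grad(y, flat, allow_unused=True))  # prints (None,)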

If you’re trying to compute the Hessian efficiently, I’d recommend using the functorch package. You can find their repo here, and their documentation on higher-order gradients here.

Is functorch much more optimized than autograd.grad? I thought functorch was mainly for per-batch analysis, i.e., getting the gradient per sample instead of the sum over all samples, and that it was about as fast as autograd.grad.

It’s not just for per-sample gradients (although functorch’s vmap will handle that very easily). I’d recommend giving the documentation a brief read; one of its main features is efficient Hessian computation. The documentation (with example code) can be found here: Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms — functorch 0.2.0 documentation
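
For a plain function of a tensor it is essentially a one-liner; a minimal sketch (my own toy example, not taken from the docs):

import torch
from functorch import hessian

def f(x):
    return x.sin().sum()

x = torch.randn(5)
H = hessian(f)(x)  # shape [5, 5]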

We can use this directly when the function is a plain function that takes tensors as arguments, so that we can compute the Hessian with respect to those arguments. However, I’m not sure how to use it when the parameters are encapsulated in a model.
For example, if the function we consider is a loss function, we first compute the outputs of the model with a forward pass and then the loss from those outputs. We could pass the model parameters to the function, but is there a way to perform the forward pass with the model parameters themselves?

Hi @maheshakya,

Can you give a clear example of this? I think what you’re referring to is torch.func.functional_call. If your loss function, written in standard PyTorch syntax, is defined as,

outputs = net(inputs)
loss = loss_fn(outputs, targets)

In the torch.func namespace, you would define it along the lines of,

params = dict(net.named_parameters())
outputs = torch.func.functional_call(net, params, inputs)
loss = loss_fn(outputs, targets)

This example is for a single sample; for multiple samples you’ll need to wrap the outputs calculation in a torch.func.vmap call to vectorize over the batch dimension.
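
A minimal sketch of that wrapping (assuming inputs carries the batch as its leading dimension):

# sketch: vectorize the functional forward pass over the batch dimension
batched_outputs = torch.func.vmap(
    lambda p, x: torch.func.functional_call(net, p, x),
    in_dims=(None, 0),  # params are shared, inputs are mapped over dim 0
)(params, inputs)
loss = loss_fn(batched_outputs, targets)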

Thanks for your reply @AlphaBetaGamma96. I have the following example in standard PyTorch syntax:
Suppose net is some pretrained model and loss_fn is some loss function.

outputs = net(inputs)
loss = loss_fn(outputs, targets)

Now I want to compute the Hessian of loss w.r.t. parameters of net. Presumably, I would do something like below using functorch.

from functorch import hessian

def compute_loss(net, inputs, targets):
    outputs = net(inputs)
    return loss_fn(outputs, targets)

h = hessian(compute_loss, argnums=0)(net, inputs, targets)

Can I use hessian from functorch like this?

The documentation of the hessian function for functorch can be found here: Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms — functorch nightly documentation

I would advise using torch.func instead (if you’re using a more recent version of PyTorch), but something similar is possible in functorch. In torch.func an example of computing the Hessian with respect to parameters would be something like,

import torch
from torch import nn
from torch.func import functional_call, vmap, hessian

class Model(nn.Module):
  def __init__(self):
    super(Model, self).__init__()
    self.fc1 = nn.Linear(1, 16)
    self.fc2 = nn.Linear(16, 1)
    self.af = nn.Tanh()

  def forward(self, x):
    x = self.fc1(x)
    x = self.af(x)
    x = self.fc2(x)
    return x.squeeze(-1)

net = Model()

batch_size=100

targets = torch.randn(batch_size)
inputs = torch.randn(batch_size, 1)
params = dict(net.named_parameters())

def fcall(params, inputs):
  outputs = functional_call(net, params, inputs)
  return outputs

def loss_fn(outputs, targets):
  return torch.mean((outputs - targets)**2, dim=0)

def compute_loss(params, inputs, targets):
  outputs = vmap(fcall, in_dims=(None, 0))(params, inputs)  # vectorize over the batch dimension
  return loss_fn(outputs, targets)

def compute_hessian_loss(params, inputs, targets):
  return hessian(compute_loss, argnums=0)(params, inputs, targets)

loss = compute_loss(params, inputs, targets)
print(loss)

hess = compute_hessian_loss(params, inputs, targets)
key = list(params.keys())[0]  # take the first-layer weight as an example key
print(hess[key][key].shape)  # Hessian of loss w.r.t. the first weight (shape [16, 1, 16, 1])
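
As a side note, hess here is a nested dict of per-parameter blocks. If you want a single matrix instead, the blocks can be assembled along the lines of this sketch (flatten_hessian is a hypothetical helper; it assumes the ordering given by params.keys()):

# hypothetical helper: assemble the nested dict of blocks into one (P, P) matrix
def flatten_hessian(hess, params):
  keys = list(params.keys())
  sizes = [params[k].numel() for k in keys]
  rows = []
  for i, ki in enumerate(keys):
    cols = [hess[ki][kj].reshape(sizes[i], sizes[j]) for j, kj in enumerate(keys)]
    rows.append(torch.cat(cols, dim=1))
  return torch.cat(rows, dim=0)

H = flatten_hessian(hess, params)
print(H.shape)  # torch.Size([49, 49]) for this model (49 parameters in total)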

This approach works. Thanks!
