Efficient computation of Hessian with respect to network weights using autograd.grad and symmetry of Hessian matrix

The documentation for functorch's hessian function can be found in the tutorial "Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms" (functorch nightly documentation).

I would advise using torch.func instead if you’re on a recent version of PyTorch (2.0 or later), although something similar is possible in functorch. With torch.func, computing the Hessian of a loss with respect to the network parameters looks something like this:

import torch
from torch import nn
from torch.func import functional_call, vmap, hessian

class Model(nn.Module):
    """Small fully connected network: Linear(1, 16) -> Tanh -> Linear(16, 1).

    The trailing dimension of the output is squeezed away, so an input whose
    last dimension is 1 yields an output with that dimension removed.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes the order of named_parameters():
        # fc1.weight, fc1.bias, fc2.weight, fc2.bias.
        self.fc1 = nn.Linear(1, 16)
        self.fc2 = nn.Linear(16, 1)
        self.af = nn.Tanh()

    def forward(self, x):
        """Apply fc1, the tanh activation and fc2, then drop the last dim."""
        hidden = self.af(self.fc1(x))
        return self.fc2(hidden).squeeze(-1)

# Build the network before sampling any data so that parameter
# initialisation is the first thing to draw from the random generator.
net = Model()

batch_size = 100

# Random regression data: one scalar target and one input feature per sample.
targets = torch.randn(batch_size)
inputs = torch.randn(batch_size, 1)

# Parameter name -> tensor mapping; this is what gets passed to
# functional_call and differentiated against.
params = dict(net.named_parameters())

def fcall(params, inputs):
    """Evaluate ``net`` on ``inputs`` with ``params`` standing in for its parameters.

    Going through ``functional_call`` makes the parameters an explicit
    argument, which is what allows them to be differentiated against.
    """
    return functional_call(net, params, inputs)

def loss_fn(outputs, targets):
    """Return the squared error between ``outputs`` and ``targets``, averaged over dim 0."""
    squared_error = (outputs - targets).pow(2)
    return squared_error.mean(dim=0)

def compute_loss(params, inputs, targets):
    """Return the loss of the network on a batch, as a function of ``params``.

    ``fcall`` is mapped over dim 0 of ``inputs`` while ``params`` is shared
    across the whole batch (``in_dims=(None, 0)``); the stacked outputs are
    then compared against ``targets`` with ``loss_fn``.
    """
    batched_forward = vmap(fcall, in_dims=(None, 0))
    predictions = batched_forward(params, inputs)
    return loss_fn(predictions, targets)

def compute_hessian_loss(params, inputs, targets):
    """Return the Hessian of ``compute_loss`` with respect to ``params``.

    Only the first argument (``argnums=0``) is differentiated; ``inputs``
    and ``targets`` are passed through unchanged.
    """
    hessian_of_loss = hessian(compute_loss, argnums=0)
    return hessian_of_loss(params, inputs, targets)

# Evaluate the loss once as a quick sanity check.
loss = compute_loss(params, inputs, targets)
print(loss)

# The result is indexed by a pair of parameter names: hess[a][b].
hess = compute_hessian_loss(params, inputs, targets)

# Use the first parameter name (the first layer's weight) as an example key.
key = next(iter(params))

# Hessian of the loss w.r.t. the first weight: shape [16, 1, 16, 1].
print(hess[key][key].shape)

1 Like