The documentation of the `hessian` function for functorch can be found here: "Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms" (functorch nightly documentation).
I would advise using `torch.func` instead (if you're using a more recent version of PyTorch), but something similar is possible in `functorch`. In `torch.func`, an example of computing the Hessian with respect to the parameters would be something like:
import torch
from torch import nn
from torch.func import functional_call, vmap, hessian
class Model(nn.Module):
    """Two-layer network: Linear(1, 16) -> Tanh -> Linear(16, 1)."""

    def __init__(self):
        super().__init__()
        # Registration order (fc1, fc2, af) determines the order of
        # ``named_parameters()``, so it is kept as in the original.
        self.fc1 = nn.Linear(1, 16)
        self.fc2 = nn.Linear(16, 1)
        self.af = nn.Tanh()

    def forward(self, x):
        """Apply fc1, tanh and fc2, then squeeze the last dimension.

        Args:
            x: Tensor whose last dimension has size 1 (the input feature).

        Returns:
            The network output with the trailing size-1 dimension removed.
        """
        hidden = self.af(self.fc1(x))
        return self.fc2(hidden).squeeze(-1)
# Build the network and a random toy regression problem.
# Statement order is kept: it fixes the order in which random numbers are drawn.
net = Model()
batch_size = 100
targets = torch.randn(batch_size)  # one scalar target per sample
inputs = torch.randn(batch_size, 1)  # one scalar feature per sample

# Parameters as a name -> tensor mapping, the form ``functional_call`` accepts.
params = dict(net.named_parameters())
def fcall(params, inputs):
    """Evaluate the module-level ``net`` with an explicit parameter mapping.

    Args:
        params: Mapping from parameter name to tensor, used in place of the
            parameters stored on ``net``.
        inputs: Input tensor forwarded to ``net``.

    Returns:
        The output of ``net`` computed with ``params``.
    """
    return functional_call(net, params, inputs)
def loss_fn(outputs, targets):
    """Return the squared error between ``outputs`` and ``targets``, averaged over dim 0.

    Args:
        outputs: Predicted values.
        targets: Reference values, broadcastable against ``outputs``.

    Returns:
        Tensor holding the mean of ``(outputs - targets) ** 2`` along dim 0.
    """
    return torch.mean((outputs - targets) ** 2, dim=0)
def compute_loss(params, inputs, targets):
    """Compute the batch loss as a function of the parameters.

    Args:
        params: Mapping from parameter name to tensor, passed to ``fcall``.
        inputs: Batched inputs; dim 0 is the dimension mapped over by ``vmap``.
        targets: Targets passed to ``loss_fn`` together with the outputs.

    Returns:
        The value of ``loss_fn`` on the vectorized model outputs.
    """
    # Share the parameters across samples (None) and map over dim 0 of inputs.
    batched_fcall = vmap(fcall, in_dims=(None, 0))
    outputs = batched_fcall(params, inputs)
    return loss_fn(outputs, targets)
def compute_hessian_loss(params, inputs, targets):
    """Compute the Hessian of ``compute_loss`` with respect to ``params``.

    Args:
        params: Mapping from parameter name to tensor; the differentiation
            variable.
        inputs: Batched inputs forwarded to ``compute_loss``.
        targets: Targets forwarded to ``compute_loss``.

    Returns:
        The result of ``torch.func.hessian``; the example indexes it as
        ``hess[name_a][name_b]`` by parameter name.
    """
    # ``argnums=0`` differentiates with respect to the first argument only.
    # The original ``(0)`` was a parenthesized int, not a tuple.
    hessian_fn = hessian(compute_loss, argnums=0)
    return hessian_fn(params, inputs, targets)
# Evaluate the loss once at the current parameters.
loss = compute_loss(params, inputs, targets)
print(loss)

# Full Hessian of the loss with respect to every parameter pair.
hess = compute_hessian_loss(params, inputs, targets)

# Use the first parameter name (the first layer's weight) as an example key.
key = next(iter(params))
# Hessian of the loss w.r.t. that weight: shape [16, 1, 16, 1].
print(hess[key][key].shape)