As the title says, I am trying to compute the Hessian with respect to the weights (not the inputs) of the network. Right now, my code is as follows:
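import torch
from torch import autograd

def hessian(loss, params):
    grads = autograd.grad(loss, params, retain_graph=True, create_graph=True)
    flattened_grads = torch.cat([grad.flatten() for grad in grads])
    n = flattened_grads.shape[0]
    hess = torch.zeros(n, n)
    for idx, grad in enumerate(flattened_grads):
        # Row idx: derivative of the idx-th (scalar) gradient entry w.r.t. all parameters
        second_der = autograd.grad(grad, params, retain_graph=True, allow_unused=True)
        # allow_unused=True returns None for parameters this entry does not depend on
        second_der = torch.cat([torch.zeros_like(p).flatten() if g is None else g.flatten()
                                for g, p in zip(second_der, params)])
        hess[idx, :] = second_der
    return hess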
That works fine, but it is inefficient: the Hessian is symmetric, so only half of it actually needs to be computed.
To make it more efficient I thought of the following:
flat_params = torch.cat([par.flatten() for par in params])
for idx, grad in enumerate(grads):
    second_der = autograd.grad(grad, flat_params[idx::], retain_graph=True, allow_unused=True)
    second_der = torch.cat([grad.flatten() for grad in second_der])
    hessian[idx, :] = second_der
Namely, I calculate the upper triangular part, and the lower half can be copied over later. Unfortunately, second_der = autograd.grad(grad, flat_params[idx::], retain_graph=True, allow_unused=True) returns None. How can I fix this? Is there a way to do what I want?
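autograd.grad can only differentiate with respect to tensors that were actually used to compute its output. flat_params (and any slice of it, such as flat_params[idx::]) is created by torch.cat after the forward pass, so the graph never saw it, and allow_unused=True turns the resulting error into None. A minimal sketch reproducing this on a toy linear model (x, w, b and theta are assumptions for illustration), together with one possible workaround: keep all parameters in a single flat leaf tensor and use views of it in the forward pass.

```python
import torch
from torch import autograd

torch.manual_seed(0)
x = torch.randn(5, 3)  # toy inputs (assumed for illustration)

# --- Why the original attempt returns None ---
w = torch.randn(3, requires_grad=True)
b = torch.randn(1, requires_grad=True)
loss = ((x @ w + b) ** 2).mean()
grads = autograd.grad(loss, (w, b), create_graph=True)
flat_grads = torch.cat([g.flatten() for g in grads])

# torch.cat builds a brand-new tensor that was never used to compute loss,
# so nothing in the graph of flat_grads depends on it.
flat_params = torch.cat([w.flatten(), b.flatten()])
(res,) = autograd.grad(flat_grads[0], flat_params, retain_graph=True, allow_unused=True)
print(res)  # None

# --- Workaround (hypothetical setup): one flat leaf tensor holds all parameters ---
theta = torch.randn(4, requires_grad=True)
w2, b2 = theta[:3], theta[3:]  # views of theta, used in the forward pass
loss2 = ((x @ w2 + b2) ** 2).mean()
(g,) = autograd.grad(loss2, theta, create_graph=True)

n = g.numel()
H = torch.zeros(n, n)
for i in range(n):
    # One backward pass yields the whole row i; keep only the upper triangle (j >= i)
    (row,) = autograd.grad(g[i], theta, retain_graph=True)
    H[i, i:] = row[i:]
# Mirror the strict upper triangle into the lower half
H = H + H.triu(1).T
```

Note that each autograd.grad call in the loop already computes a full row, so restricting to j >= i saves no backward work here.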
If you’re trying to compute a Hessian efficiently, I’d recommend using the functorch package. You can find their repo here, and their documentation on higher-order gradients here.
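For a plain function of a tensor, a minimal sketch using torch.func (functorch was merged into PyTorch as torch.func in version 2.0) might look like this; the sum-of-cubes function is only an illustration, chosen because its Hessian is diag(6 * x).

```python
import torch
from torch.func import hessian, jacfwd, jacrev

def f(x):
    # Toy scalar function: sum of cubes, whose Hessian is diag(6 * x)
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0, 3.0])
H = hessian(f)(x)
print(H)  # diagonal entries 6, 12, 18; zeros elsewhere

# torch.func.hessian is documented as forward-over-reverse, i.e. jacfwd(jacrev(f))
H2 = jacfwd(jacrev(f))(x)
```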
Is functorch much more optimized than autograd.grad? I thought functorch was mainly meant for per-sample analysis, i.e., getting the gradient of each sample instead of the sum over all samples, but when I tried it, it was as fast as autograd.grad.
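A minimal sketch of the per-sample use case on a toy linear model (w, xs, ys and sample_loss are hypothetical): vmap(grad(...)) returns one gradient per sample, which is checked against a Python loop of autograd.grad calls.

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)
w = torch.randn(3)
xs = torch.randn(8, 3)
ys = torch.randn(8)

def sample_loss(w, x, y):
    # Squared error for a single (x, y) pair under a toy linear model
    return (x @ w - y) ** 2

# Vectorize over the batch dimension of x and y while sharing w
per_sample = vmap(grad(sample_loss), in_dims=(None, 0, 0))(w, xs, ys)
print(per_sample.shape)  # torch.Size([8, 3])

# Same result with a Python loop of autograd.grad calls
w_req = w.clone().requires_grad_(True)
looped = torch.stack([torch.autograd.grad(sample_loss(w_req, x, y), w_req)[0]
                      for x, y in zip(xs, ys)])

# Gradient of the mean batch loss, for comparison
batch_grad = grad(lambda w: vmap(sample_loss, in_dims=(None, 0, 0))(w, xs, ys).mean())(w)
```

Averaging the per-sample gradients recovers the ordinary gradient of the mean loss, which is what a single autograd.grad call on the batch loss returns.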
We can technically use this if we have a simple function that takes tensors as arguments, since the Hessian can then be computed with respect to those arguments. However, I’m not sure how to use it when the parameters are encapsulated in a model.
For example, if the function we consider is a loss function where we first compute the model's outputs with a forward pass and then compute the loss from those outputs, we could pass the model parameters to the function, but is there a way to run the forward pass using those parameters directly?
Can you give a clear example of this? I think what you’re referring to is torch.func.functional_call. If your loss function, written in standard PyTorch syntax, is defined as,
outputs = net(inputs)
loss = loss_fn(outputs, targets)
In the torch.func namespace, you would define it along the lines of,
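A minimal sketch of such a definition, assuming net is an nn.Module and loss_fn a loss function; the nn.Linear model, nn.MSELoss and the single random sample below are only stand-ins.

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)
net = nn.Linear(3, 1)    # stand-in for net
loss_fn = nn.MSELoss()   # stand-in for loss_fn
inputs = torch.randn(3)  # a single sample
targets = torch.randn(1)

# Parameters as a dict {name: tensor}, passed explicitly as a function argument
params = dict(net.named_parameters())

def compute_loss(params, inputs, targets):
    # Run net's forward pass with the tensors in params instead of its own attributes
    outputs = functional_call(net, params, (inputs,))
    return loss_fn(outputs, targets)

loss = compute_loss(params, inputs, targets)
```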
This example is for a single sample; for multiple samples, you’ll need to wrap the output calculation in a torch.func.vmap call to vectorize over the batch dimension.
Thanks for your reply @AlphaBetaGamma96. I have the following example in standard PyTorch syntax:
Suppose net is some pretrained model and loss_fn is some loss function.
outputs = net(inputs)
loss = loss_fn(outputs, targets)
Now I want to compute the Hessian of loss w.r.t. the parameters of net. Presumably, I would do something like the following using functorch:
from functorch import hessian

def compute_loss(net, inputs, targets):
    outputs = net(inputs)
    return loss_fn(outputs, targets)

h = hessian(compute_loss, argnums=0)(net, inputs, targets)
I would advise using torch.func instead if you’re on a more recent version of PyTorch, although something similar is possible in functorch. In torch.func, computing the Hessian with respect to the parameters would look something like this:
import torch
from torch import nn
from torch.func import functional_call, vmap, hessian

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 16)
        self.fc2 = nn.Linear(16, 1)
        self.af = nn.Tanh()

    def forward(self, x):
        x = self.fc1(x)
        x = self.af(x)
        x = self.fc2(x)
        return x.squeeze(-1)

net = Model()

batch_size = 100
targets = torch.randn(batch_size)
inputs = torch.randn(batch_size, 1)

# Parameters as a dict {name: tensor}, so they can be passed as function arguments
params = dict(net.named_parameters())

def fcall(params, inputs):
    # Forward pass of net using the tensors in params instead of net's own parameters
    outputs = functional_call(net, params, inputs)
    return outputs

def loss_fn(outputs, targets):
    return torch.mean((outputs - targets)**2, dim=0)

def compute_loss(params, inputs, targets):
    outputs = vmap(fcall, in_dims=(None, 0))(params, inputs)  # vectorize over batch
    return loss_fn(outputs, targets)

def compute_hessian_loss(params, inputs, targets):
    # Differentiate twice w.r.t. argument 0 (params); the result is a nested dict hess[k1][k2]
    return hessian(compute_loss, argnums=0)(params, inputs, targets)

loss = compute_loss(params, inputs, targets)
print(loss)

hess = compute_hessian_loss(params, inputs, targets)

key = list(params.keys())[0]  # take weight in first layer as example key
print(hess[key][key].shape)  # Hessian of loss w.r.t. first weight (shape [16, 1, 16, 1])
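If a single (P x P) matrix is needed instead of the nested dict, the blocks hess[k1][k2] (shape params[k1].shape + params[k2].shape) can be reshaped and stitched together. A minimal sketch, assuming a small nn.Sequential stand-in model and a hypothetical helper hessian_dict_to_matrix; here functional_call is applied to the whole batch directly, which works for standard layers such as nn.Linear.

```python
import torch
from torch import nn
from torch.func import functional_call, hessian

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(1, 4), nn.Tanh(), nn.Linear(4, 1))  # small stand-in model
inputs = torch.randn(10, 1)
targets = torch.randn(10)
params = {k: v.detach() for k, v in net.named_parameters()}

def compute_loss(params, inputs, targets):
    outputs = functional_call(net, params, (inputs,)).squeeze(-1)
    return torch.mean((outputs - targets) ** 2)

hess = hessian(compute_loss)(params, inputs, targets)  # nested dict: hess[k1][k2]

def hessian_dict_to_matrix(hess, params):
    # Reshape each block to (numel of k1, numel of k2) and concatenate the blocks
    keys = list(params.keys())
    rows = []
    for k1 in keys:
        n1 = params[k1].numel()
        blocks = [hess[k1][k2].reshape(n1, params[k2].numel()) for k2 in keys]
        rows.append(torch.cat(blocks, dim=1))
    return torch.cat(rows, dim=0)

H = hessian_dict_to_matrix(hess, params)
print(H.shape)  # torch.Size([13, 13])
```

The row and column order follows net.named_parameters(), so the first four rows/columns correspond to the first layer's weight.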