Hi,
I want to compute the second derivative (Hessian) of the loss w.r.t. the parameters of the model, and I proceed batch-wise. Here is some code to illustrate my problem:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.autograd.functional import hessian
# simple dataset
class data_set(Dataset):
    def __init__(self):
        self.input = torch.randn(10, 2)
        self.output = torch.randn(10, 1)
    def __len__(self):
        return len(self.input)
    def __getitem__(self, index):
        return self.input[index], self.output[index]
dataset = data_set()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# simple model
class Model(nn.Module):
    def __init__(self, input_size, layers_size, output_size):
        super(Model, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_size, layers_size),
            nn.GELU(),
            nn.Linear(layers_size, output_size),
        )
    def forward(self, x):
        return self.mlp(x)
model = Model(2,3,1)
flat_params = torch.nn.utils.parameters_to_vector(model.parameters())
loss_func = nn.MSELoss()
# compute Hessian
hess = 0.
for i, batch in enumerate(dataloader):
    inputs, targets = batch
    batch_size = len(inputs)
    def f(theta_):
        torch.nn.utils.vector_to_parameters(theta_, model.parameters())
        preds = torch.func.functional_call(model, model.state_dict(), inputs)
        loss = loss_func(preds, targets)
        return loss
    hess_batch = torch.autograd.functional.hessian(f, flat_params).detach()
    hess = hess + hess_batch * batch_size
print("Hessian =", hess)
Unfortunately, the Hessian comes out as all zeros.
Either I made a silly mistake, or theta_ is not "seen" in the computational graph.
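My guess is that vector_to_parameters copies the values of theta_ into the existing parameters outside of autograd, and model.state_dict() then hands functional_call tensors that do not depend on theta_, so the loss has zero gradient w.r.t. the input of hessian. If that is the case, I imagine something like the sketch below would keep theta_ in the graph, by splitting the flat vector back into a name -> tensor dict (with differentiable ops) and passing that dict to functional_call directly. This is untested; names, shapes and numels are just helper lists I made up, and the rest of the batch loop would stay the same:

names  = [n for n, _ in model.named_parameters()]
shapes = [p.shape for p in model.parameters()]
numels = [p.numel() for p in model.parameters()]

def f(theta_):
    # split/view are differentiable, so the loss now depends on theta_ itself
    chunks = torch.split(theta_, numels)
    params = {n: c.view(s) for n, c, s in zip(names, chunks, shapes)}
    preds = torch.func.functional_call(model, params, inputs)
    return loss_func(preds, targets)

hess_batch = torch.autograd.functional.hessian(f, flat_params).detach()

Does that reasoning sound right, or is there a cleaner way to do this?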
Big thanks for any help.