Hessian of loss w.r.t. model parameters is all 0

Hi,

I want to compute the second derivative (Hessian) of the loss w.r.t. the parameters of the model. I proceed batch-wise. Here is some code to illustrate my problem:


import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.autograd.functional import hessian

# simple dataset
class RandomDataset(Dataset):
    def __init__(self):
        # random inputs and targets, just to have something to compute a loss on
        self.input = torch.randn(10, 2)
        self.output = torch.randn(10, 1)
        
    def __len__(self):
        return len(self.input)

    def __getitem__(self, index):
        return self.input[index], self.output[index]

dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# simple model
class Model(nn.Module):
    def __init__(self, input_size, layers_size, output_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_size, layers_size),
            nn.GELU(),
            nn.Linear(layers_size, output_size),
        )
    
    def forward(self, x):
        return self.mlp(x)

model = Model(2,3,1)
# flatten all parameters of the model into a single vector
flat_params = torch.nn.utils.parameters_to_vector(model.parameters())

loss_func = nn.MSELoss()

# compute Hessian
hess = 0.

for i, batch in enumerate(dataloader):
    inputs, targets = batch
    batch_size = len(inputs)
    
    def f(theta_):
        # write the flat vector back into the model's parameters,
        # then evaluate the loss on the current batch
        torch.nn.utils.vector_to_parameters(theta_, model.parameters())
        preds = torch.func.functional_call(model, model.state_dict(), inputs)
        loss = loss_func(preds, targets)
        return loss
    
    # Hessian of the batch loss w.r.t. the flat parameter vector
    hess_batch = hessian(f, flat_params).detach()
    hess = hess + hess_batch * batch_size  # weight each batch by its size

print("Hessian =", hess)

Unfortunately, the Hessian is all 0.
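To narrow it down, I ran a small check (reusing model and flat_params from above); my assumption is that, if vector_to_parameters kept the graph, each parameter would show a grad_fn pointing back to the vector:

# check whether copying a vector into the parameters keeps any autograd history
theta = flat_params.detach().clone().requires_grad_(True)
torch.nn.utils.vector_to_parameters(theta, model.parameters())
for name, p in model.named_parameters():
    print(name, p.grad_fn)

Unless I am misreading this, every parameter prints grad_fn = None.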

Either I made a silly mistake or theta_ is not “seen” in the computational graph.
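One thing I was wondering: should the parameter dict for functional_call be built directly from theta_, instead of going through vector_to_parameters and state_dict()? Below is a rough sketch of what I mean; the splitting logic and the use of named_parameters() for the ordering are just my guesses at how this is supposed to be wired up, so I may well be missing something:

# (this would replace f inside the batch loop)
def f(theta_):
    # split the flat vector back into named tensors; the slices stay
    # connected to theta_, so autograd should be able to differentiate through them
    params = {}
    offset = 0
    for name, p in model.named_parameters():
        n = p.numel()
        params[name] = theta_[offset:offset + n].view_as(p)
        offset += n
    preds = torch.func.functional_call(model, params, inputs)
    return loss_func(preds, targets)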
Big thanks for any help.