Proper way to create the Hessian as a tensor with torch.func.hessian

When using torch.func.hessian to compute the Hessian of a loss with respect to a model's parameters, you end up with a nested dict keyed by parameter names.

Working with a nested dict can be cumbersome, especially for mathematical operations. I therefore convert it to a single torch.Tensor, in the form of a standard symmetric Hessian matrix, using the script below.

However, this is unintuitive and requires quite a few lines of code.

Is there a way to do this with fewer lines of code, and more efficiently?

import torch

# Toy model and a dummy classification example
model = torch.nn.Linear(2, 2)
params = dict(model.named_parameters())

x = torch.randn(2)
y = torch.tensor(1)


def f(model, params, inputs, target):
    # Stateless forward pass using the supplied parameter dict
    predictions = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.cross_entropy(predictions, target)


# Second derivatives with respect to argument 1 (the parameter dict)
hessian_dict = torch.func.hessian(f, argnums=1)(model, params, x, y)
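
# hessian_dict is a nested dict: hessian_dict[k1][k2] holds the second
# derivatives of the loss with respect to params[k1] and params[k2], with
# shape params[k1].shape + params[k2].shape (e.g. hessian_dict["weight"]["bias"]
# has shape (2, 2, 2) here).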

hessian = None
# Loop over the outer keys (one block row per parameter tensor)
for key in hessian_dict:
    param_dim = params[key].numel()

    # Flatten each block to (1, param_dim, numel of inner parameter) and
    # concatenate the blocks of this row along the last dimension
    concat_tensor = torch.cat(
        [hessian_dict[key][k].reshape(1, param_dim, -1) for k in hessian_dict[key]],
        dim=2,
    )

    # Drop the leading singleton dimension and stack the finished block
    # row under the rows assembled so far
    hessian = (
        torch.cat([hessian, torch.sum(concat_tensor, dim=0)], dim=0)
        if hessian is not None
        else torch.sum(concat_tensor, dim=0)
    )
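
As a sanity check, the result should be an N x N matrix, where N is the total number of parameters (2*2 + 2 = 6 for this model), and symmetric up to floating-point error:

total_dim = sum(p.numel() for p in params.values())
assert hessian.shape == (total_dim, total_dim)
assert torch.allclose(hessian, hessian.T, atol=1e-6)  # symmetric up to float error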