Computing Hessian diagonals efficiently with torch.func

I’m looking for a way to efficiently compute the diagonal of the Hessian of a model’s loss with respect to its parameters, ideally with torch.func and its vmap capabilities.
Due to memory constraints, I want to avoid computing the full Hessian and instead use Hessian-vector products (HVPs).
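
Concretely, I’m relying on the identity Hv = ∇_θ((∇_θ L(θ)) · v), so the i-th diagonal entry can be read off as (H e_i)_i, where e_i is the i-th standard basis vector, at the cost of one HVP per parameter.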

So far, I’ve been able to compute the diagonal values correctly, but the computation is incredibly slow. Any ideas on how to make it faster?

import torch


def f(model, params, input, target):
    # evaluate the loss as a pure function of the parameter dict
    predictions = torch.func.functional_call(model, params, (input,))
    return torch.nn.functional.cross_entropy(predictions, target)


def hvp(f, model, params, input, target, v):
    # Hessian-vector product via double backward:
    # H @ v = grad_params( grad_params(f) . v )
    return torch.func.grad(
        lambda params: torch.dot(
            torch.cat(
                [
                    p.flatten()
                    for p in torch.func.grad(f, argnums=1)(
                        model, params, input, target
                    ).values()
                ]
            ),
            v,
        )
    )(params)


model = torch.nn.Linear(2, 2)
params = dict(model.named_parameters())

x = torch.randn(2)
y = torch.tensor(1)

num_params = sum(p.numel() for p in params.values())
diag_entries = []
for p_i in range(num_params):
    # one-hot basis vector e_i
    v = torch.zeros(num_params)
    v[p_i] = 1.0
    hvp_dict = hvp(f, model, params, x, y, v)
    # the i-th entry of H @ e_i is the i-th diagonal element
    hv_flat = torch.cat([t.flatten() for t in hvp_dict.values()])
    diag_entries.append(hv_flat[p_i])
hessian_diag = torch.stack(diag_entries)

I tried to vectorise the computation of the dot product, but with no luck. That would definitely speed things up, since the loop above runs one full double-backward pass per parameter; however, I’m open to completely different ideas as well.
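
For reference, here is a rough sketch of the kind of vectorised variant I have in mind (flat_loss and diag_entry are just names I made up; it flattens the parameters into a single vector and swaps the double backward for a forward-over-reverse jvp-of-grad HVP, and I’m not sure this is the idiomatic way to batch it):

import torch

# Sketch only: flatten the parameters into one vector so that the
# basis vectors can be vmapped over. Uses the same model/params/x/y
# as in the snippet above.
names = list(params.keys())
shapes = [p.shape for p in params.values()]
numels = [p.numel() for p in params.values()]
flat_params = torch.cat([p.detach().flatten() for p in params.values()])


def flat_loss(flat):
    # rebuild the parameter dict from a single flat vector
    chunks = torch.split(flat, numels)
    p = {n: c.reshape(s) for n, c, s in zip(names, chunks, shapes)}
    out = torch.func.functional_call(model, p, (x,))
    return torch.nn.functional.cross_entropy(out, y)


def diag_entry(v):
    # forward-over-reverse HVP: jvp through grad computes H @ v;
    # dotting with v again gives v^T H v, i.e. H_ii for v = e_i
    _, hv = torch.func.jvp(torch.func.grad(flat_loss), (flat_params,), (v,))
    return torch.dot(hv, v)


# one diagonal entry per basis vector, batched with vmap
hessian_diag = torch.func.vmap(diag_entry)(torch.eye(num_params))

One thing I’m unsure about is memory: vmapping over the full identity matrix presumably materialises the whole batch of H @ e_i products at once, so the basis vectors may need to be processed in chunks (I believe vmap takes a chunk_size argument for this, but I haven’t verified that).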