# Computing Hessian diagonals efficiently with torch.func

I’m looking for an efficient way to compute the diagonal of the Hessian of a model’s loss with respect to its parameters (for a given input), ideally using `torch.func` and its `vmap` capabilities.
Due to memory constraints, I want to avoid computing the full Hessian and instead use Hessian-vector products (HVPs).
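For reference, an HVP can also be computed forward-over-reverse by composing `torch.func.jvp` with `torch.func.grad`. The `torch.func` documentation describes this as generally more memory-efficient than reverse-over-reverse (taking the gradient of a dot product with the gradient). A minimal sketch, assuming a cross-entropy loss on a single unbatched example; `loss_fn` and `hvp` are hypothetical helper names:

```python
import torch
from torch.func import functional_call, grad, jvp

torch.manual_seed(0)
model = torch.nn.Linear(2, 2)
# Detached copies: torch.func transforms track gradients themselves
params = {k: p.detach() for k, p in model.named_parameters()}
x = torch.randn(2)
y = torch.tensor(1)


def loss_fn(params):
    # Assumed loss: cross-entropy on one unbatched example with class target y
    logits = functional_call(model, params, (x,))
    return torch.nn.functional.cross_entropy(logits, y)


def hvp(params, v):
    # Forward-over-reverse: push the tangent v through grad(loss_fn).
    # The result is H @ v, returned as a dict shaped like params.
    _, hv = jvp(grad(loss_fn), (params,), (v,))
    return hv


# Any direction with the same structure as params works; all ones here
v = {k: torch.ones_like(p) for k, p in params.items()}
hv = hvp(params, v)
```

The full Hessian is never formed; only one gradient evaluation plus one forward-mode pass is needed per direction `v`.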

So far, I’ve been able to compute the diagonal values correctly, but the computation is very slow, since it runs one full HVP per parameter in a Python loop. Any ideas on how to make it faster?
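If an approximate diagonal is acceptable, one different idea is Hutchinson’s estimator: for random probe vectors `z` with independent ±1 entries, the expectation of `z * (H @ z)` equals `diag(H)`, so the cost is set by the number of probes rather than the number of parameters. A minimal sketch, assuming a cross-entropy loss and the same toy `nn.Linear(2, 2)` setup; the helpers `unflatten`, `loss_flat`, `hvp_flat`, and `hutchinson_diag` are hypothetical, and `num_samples` would need tuning for the accuracy required:

```python
import torch
from torch.func import functional_call, grad, jvp, vmap

torch.manual_seed(0)
model = torch.nn.Linear(2, 2)
params = {k: p.detach() for k, p in model.named_parameters()}
x = torch.randn(2)
y = torch.tensor(1)

# Flatten all parameters into one vector so probes are plain 1-D tensors
names = list(params)
shapes = [params[n].shape for n in names]
sizes = [params[n].numel() for n in names]
flat = torch.cat([params[n].flatten() for n in names])


def unflatten(theta):
    # Split a flat vector back into a dict shaped like params
    return {n: t.reshape(s) for n, t, s in zip(names, theta.split(sizes), shapes)}


def loss_flat(theta):
    # Assumed loss: cross-entropy on one unbatched example
    logits = functional_call(model, unflatten(theta), (x,))
    return torch.nn.functional.cross_entropy(logits, y)


def hvp_flat(v):
    # H @ v via forward-over-reverse, without forming H
    return jvp(grad(loss_flat), (flat,), (v,))[1]


def hutchinson_diag(num_samples=1000):
    # Rademacher probes: entries are +1 or -1 with equal probability,
    # so the mean of z * (H @ z) is an unbiased estimate of diag(H)
    z = torch.randint(0, 2, (num_samples, flat.numel())).to(flat.dtype) * 2 - 1
    hz = vmap(hvp_flat)(z)  # all probe HVPs in one batched call
    return (z * hz).mean(dim=0)


estimate = hutchinson_diag()
```

The variance of each entry grows with the off-diagonal Hessian entries in that row, so this trades exactness for a cost that does not scale with the parameter count.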

```python
import torch


def f(model, params, input, target):
    predictions = torch.func.functional_call(model, params, (input,))
    # Loss line assumed (missing from the listing): cross-entropy matches
    # the integer class target y = torch.tensor(1) used below
    return torch.nn.functional.cross_entropy(predictions, target)


def hvp(f, model, params, input, target, v):
    # Gradient of <grad f(params), v> with respect to params, i.e. H @ v
    return torch.func.grad(
        lambda params: torch.dot(
            torch.cat(
                [
                    p.flatten()
                    for p in torch.func.grad(f, argnums=1)(
                        model, params, input, target
                    ).values()
                ]
            ),
            v,
        )
    )(params)


model = torch.nn.Linear(2, 2)
params = dict(model.named_parameters())

x = torch.randn(2)
y = torch.tensor(1)

num_params = sum(p.numel() for p in params.values())
hessian_diag = []
for p_i in range(num_params):
    # One-hot direction e_i: H @ e_i is the i-th column of the Hessian
    v = torch.zeros(num_params)
    v[p_i] = 1
    hvp_dict = hvp(f, model, params, x, y, v)
    # Flatten the HVP into one vector and keep only the diagonal entry
    hessian_diag.append(torch.cat([h.flatten() for h in hvp_dict.values()])[p_i])
hessian_diag = torch.stack(hessian_diag)
```

I tried to vectorise the dot-product computation, without success. That should speed things up, but I’m also open to completely different approaches.
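One way to vectorise the loop is to batch the HVPs with `torch.func.vmap`. For a one-hot vector `e_i`, the dot product `e_i · (H @ e_i)` equals `H_ii`, so each batched call returns a single diagonal entry and no full Hessian column is kept. A minimal sketch, assuming a cross-entropy loss and a flattened parameter vector; `loss_flat`, `diag_entry`, `hessian_diag`, and its `chunk_size` parameter are hypothetical names. It still performs one HVP per parameter, so any speedup comes from batching rather than from doing less work:

```python
import torch
from torch.func import functional_call, grad, jvp, vmap

torch.manual_seed(0)
model = torch.nn.Linear(2, 2)
params = {k: p.detach() for k, p in model.named_parameters()}
x = torch.randn(2)
y = torch.tensor(1)

names = list(params)
shapes = [params[n].shape for n in names]
sizes = [params[n].numel() for n in names]
flat = torch.cat([params[n].flatten() for n in names])


def loss_flat(theta):
    # Rebuild the params dict from the flat vector, then apply the assumed
    # cross-entropy loss on one unbatched example
    chunks = theta.split(sizes)
    p = {n: t.reshape(s) for n, t, s in zip(names, chunks, shapes)}
    return torch.nn.functional.cross_entropy(functional_call(model, p, (x,)), y)


def diag_entry(e):
    # For a one-hot e = e_i, e . (H @ e) is exactly H_ii
    _, hv = jvp(grad(loss_flat), (flat,), (e,))
    return torch.dot(e, hv)


def hessian_diag(chunk_size=4):
    n = flat.numel()
    out = []
    for start in range(0, n, chunk_size):
        idx = torch.arange(start, min(start + chunk_size, n))
        # At most chunk_size one-hot vectors (and HVPs) exist at a time
        basis = torch.nn.functional.one_hot(idx, n).to(flat.dtype)
        out.append(vmap(diag_entry)(basis))
    return torch.cat(out)


diag = hessian_diag()
```

`chunk_size` trades memory for batching: peak memory grows roughly with `chunk_size * num_params`, while larger chunks give `vmap` more work per call. How much this gains over the Python loop depends on the model and hardware.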