I’m looking for an efficient way to compute the diagonal of the Hessian of a model’s loss with respect to its parameters, ideally using `torch.func` and its `vmap` capabilities.

Because of memory constraints, I want to avoid materialising the full Hessian and use Hessian-vector products instead.
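For reference, the idea rests on a standard identity: a Hessian-vector product with the $i$-th standard basis vector $e_i$ gives column $i$ of the Hessian $H$, and its $i$-th component is the diagonal entry. The symbols $H$, $e_i$ and $n$ (the number of parameters) are notation introduced here for illustration only.

```latex
H_{ii} = e_i^\top \left( H e_i \right), \qquad i = 1, \dots, n
```

Reading the diagonal off this way takes $n$ Hessian-vector products, one per parameter.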

So far, I’ve been able to compute the diagonal values correctly, but the computation is incredibly slow. Any ideas on how to make it faster?

```
import torch


def f(model, params, input, target):
    # Loss as a function of the parameter dict (stateless forward pass).
    predictions = torch.func.functional_call(model, params, (input,))
    return torch.nn.functional.cross_entropy(predictions, target)


def hvp(f, model, params, input, target, v):
    # Hessian-vector product: differentiate dot(flattened gradient, v)
    # with respect to the parameters.
    return torch.func.grad(
        lambda params: torch.dot(
            torch.cat(
                [
                    p.flatten()
                    for p in torch.func.grad(f, argnums=1)(
                        model, params, input, target
                    ).values()
                ]
            ),
            v,
        )
    )(params)


model = torch.nn.Linear(2, 2)
params = dict(model.named_parameters())
x = torch.randn(2)
y = torch.tensor(1)

num_params = sum(p.numel() for p in params.values())
hessian_diag = None

# One HVP per parameter: v is the p_i-th one-hot vector, and entry p_i
# of the result is the p_i-th diagonal element of the Hessian.
for p_i in range(num_params):
    v = torch.zeros(num_params)
    v[p_i] = 1
    hvp_dict = hvp(f, model, params, x, y, v)
    hessian_diag = (
        torch.cat(
            [
                hessian_diag,
                torch.cat([hvp_dict[param].flatten() for param in hvp_dict])[
                    p_i
                ].unsqueeze(0),
            ],
            dim=0,
        )
        if hessian_diag is not None
        else torch.cat([hvp_dict[param].flatten() for param in hvp_dict])[
            p_i
        ].unsqueeze(0)
    )
```

I tried to vectorise the dot-product computation, with no luck. Vectorising it would likely speed things up, but I’m also open to completely different approaches.
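One direction that might be worth trying is to batch the one-hot vectors with `torch.func.vmap` instead of looping over them in Python. The sketch below is a rough illustration and has not been benchmarked. It assumes the parameters can be flattened into a single vector, and it swaps the grad-of-grad product for a forward-over-reverse one (`jvp` of `grad`). The helpers `unflatten`, `loss_flat` and `diag_entry` are hypothetical names introduced here, and `chunk_size=2` is an arbitrary choice for bounding memory.

```python
import torch
from torch.func import functional_call, grad, jvp, vmap

torch.manual_seed(0)
model = torch.nn.Linear(2, 2)
params = {k: v.detach() for k, v in model.named_parameters()}
x = torch.randn(2)
y = torch.tensor(1)

names = list(params.keys())
shapes = [params[n].shape for n in names]
sizes = [params[n].numel() for n in names]
num_params = sum(sizes)

# Flat copy of all parameters, in the same order as `names`.
flat0 = torch.cat([params[n].flatten() for n in names])


def unflatten(flat):
    # Hypothetical helper: rebuild the parameter dict from a flat vector.
    pieces = flat.split(sizes)
    return {n: p.reshape(s) for n, p, s in zip(names, pieces, shapes)}


def loss_flat(flat):
    # Loss as a function of the flat parameter vector.
    out = functional_call(model, unflatten(flat), (x,))
    return torch.nn.functional.cross_entropy(out, y)


def diag_entry(v):
    # Forward-over-reverse HVP: push the tangent v through grad(loss).
    _, hv = jvp(grad(loss_flat), (flat0,), (v,))
    # For a one-hot v, v^T H v is the matching diagonal entry, so only
    # one scalar per basis vector is returned.
    return (hv * v).sum()


# vmap over the rows of the identity; chunk_size limits how many HVPs
# are evaluated at once, trading speed for memory.
hessian_diag = vmap(diag_entry, chunk_size=2)(torch.eye(num_params))
print(hessian_diag)
```

This still performs one Hessian-vector product per parameter, so the total work is unchanged. Any gain would come from running them in batches rather than one at a time in a Python loop. A larger `chunk_size` should be faster but uses more memory, and leaving it unset processes all basis vectors at once, which uses memory comparable to the full Hessian.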