I’m looking for an efficient way to compute the diagonal of the Hessian of a model’s loss with respect to its parameters, ideally using torch.func and its vmap capabilities.
Due to memory constraints, I want to avoid materialising the full Hessian and instead build the diagonal from Hessian-vector products (HVPs).
So far, I’ve been able to compute the diagonal values correctly, but the computation is incredibly slow. Any ideas how to make it faster?
import torch


def f(model, params, input, target):
    # Stateless forward pass with the given parameter dict, then the loss.
    predictions = torch.func.functional_call(model, params, (input,))
    return torch.nn.functional.cross_entropy(predictions, target)


def hvp(f, model, params, input, target, v):
    # Hessian-vector product via double backward: differentiate
    # dot(flattened gradient, v) with respect to params.
    def grad_dot_v(params):
        flat_grad = torch.cat(
            [
                p.flatten()
                for p in torch.func.grad(f, argnums=1)(
                    model, params, input, target
                ).values()
            ]
        )
        return torch.dot(flat_grad, v)

    return torch.func.grad(grad_dot_v)(params)
model = torch.nn.Linear(2, 2)
params = dict(model.named_parameters())
x = torch.randn(2)
y = torch.tensor(1)

num_params = sum(p.numel() for p in params.values())

# One HVP per parameter: multiply the Hessian by the i-th basis vector
# and keep only the i-th entry of the result.
diag_entries = []
for p_i in range(num_params):
    v = torch.zeros(num_params)
    v[p_i] = 1.0
    hvp_dict = hvp(f, model, params, x, y, v)
    flat_hvp = torch.cat([t.flatten() for t in hvp_dict.values()])
    diag_entries.append(flat_hvp[p_i])
hessian_diag = torch.stack(diag_entries)
I tried to vectorise the computation of the dot product, with no luck; a rough sketch of the kind of vectorisation I have in mind is below. That would definitely speed up the computation, but I’m open to completely different ideas as well.
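For reference, here is a minimal sketch of the batched variant, assuming torch.func.vmap composes with the hvp above (the chunk_size keyword is my assumption for capping peak memory; without chunking, this materialises every Hessian row just to read off the diagonal):

def flat_hvp(v):
    # Same HVP as above, flattened into a single vector so that
    # vmap can batch it over basis vectors.
    hvp_dict = hvp(f, model, params, x, y, v)
    return torch.cat([t.flatten() for t in hvp_dict.values()])

# Batch the HVP over all basis vectors at once; chunk_size (assumed
# available in recent PyTorch versions) trades speed for memory.
basis = torch.eye(num_params)
rows = torch.func.vmap(flat_hvp, chunk_size=1)(basis)  # (num_params, num_params)
hessian_diag = rows.diagonal()

Even vectorised like this, each chunk still computes full Hessian rows only to discard everything off the diagonal, which is why I suspect a fundamentally different approach may be needed.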