I am interested in computing the diagonal values of the Hessian of the loss function with respect to the parameters of an nn.Module
to use for statistical inference (standard errors depend on the Hessian). The complication is that the model is estimated using mini-batches, because the full data doesn’t fit into memory, and existing solutions I’ve come across only give you the Hessian for the loss evaluated on a single batch of data.
Does anyone have ideas for how to compute the Hessian evaluated on the full dataset? For a single batch, an (admittedly not safe) method like the code below works. My current thinking is that I can just compute the Hessian of a loss with `sum` reduction on each batch separately and then sum across all batches (sketched after the snippet below), but I haven't yet convinced myself this is valid, so I figured I'd see if someone else had thoughts.
import torch
import torch.nn as nn

_input = torch.randn(32, 3)
layer = nn.Linear(3, 4)
criterion = nn.CrossEntropyLoss()
weight = layer.weight

def func(weight):
    # Swap the module's Parameter for the tensor that hessian() passes in,
    # so the loss becomes a function of `weight` that autograd can differentiate twice
    del layer.weight
    layer.weight = weight
    return criterion(layer(_input), torch.zeros(len(_input), dtype=torch.long))

torch.autograd.functional.hessian(func, weight)
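To make the idea concrete, here is a rough sketch of the batch-wise accumulation I have in mind (the dataset, batch_size, and the batch loop below are just placeholders for illustration): compute each batch's Hessian for the sum-reduced loss with the same parameter-swapping trick, add up the flattened Hessians, and read off the diagonal at the end.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder "full dataset" that would normally be streamed in mini-batches
full_input = torch.randn(128, 3)
full_target = torch.zeros(128, dtype=torch.long)
batch_size = 32

layer = nn.Linear(3, 4)
criterion = nn.CrossEntropyLoss(reduction="sum")  # sum, not mean, so batch losses add up

# Point at which the Hessian is evaluated (the fitted parameters)
weight0 = layer.weight.detach().clone()
n_params = weight0.numel()
hessian_sum = torch.zeros(n_params, n_params)

for start in range(0, len(full_input), batch_size):
    batch_input = full_input[start:start + batch_size]
    batch_target = full_target[start:start + batch_size]

    def batch_loss(w):
        # Same (unsafe) parameter-swapping trick as in the single-batch snippet above
        del layer.weight
        layer.weight = w
        return criterion(layer(batch_input), batch_target)

    # Hessian of this batch's sum-reduced loss, shape (4, 3, 4, 3)
    h = torch.autograd.functional.hessian(batch_loss, weight0)
    hessian_sum += h.reshape(n_params, n_params)

# Candidate Hessian of the total (sum-reduced) loss over all batches;
# dividing by len(full_input) would correspond to a mean-reduced loss instead
hessian_diag = torch.diagonal(hessian_sum)
print(hessian_diag)

This is just how I would translate the idea into code; I'm still not sure whether summing the per-batch Hessians like this is actually equivalent to the Hessian of the loss on the full dataset, hence the question.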
Related links: