Diagonal Hessian wrt parameters when using mini-batches

I am interested in computing the diagonal of the Hessian of the loss with respect to the parameters of an nn.Module, to use for statistical inference (standard errors depend on the Hessian). The complication is that the model is estimated with mini-batches because the full dataset doesn’t fit into memory, and the existing solutions I’ve come across only give you the Hessian of the loss evaluated on a single batch of data.

Does anyone have ideas for how to compute the Hessian evaluated on the full dataset? For a single batch, a (not exactly safe) hack like the code below works. My current thinking is that I can compute the Hessian of the loss with `sum` reduction on each batch separately and then sum across all batches (see the sketch after the single-batch example below), but I haven’t yet convinced myself this is valid, so I figured I’d see if anyone else has thoughts.

import torch
import torch.nn as nn

_input = torch.randn(32, 3)
layer = nn.Linear(3, 4)
criterion = nn.CrossEntropyLoss()

weight = layer.weight

def func(weight):
    # Hack: swap the registered Parameter out for the tensor that hessian()
    # differentiates through, so layer(_input) becomes a function of `weight`.
    del layer.weight
    layer.weight = weight
    return criterion(layer(_input), torch.zeros(len(_input), dtype=torch.long))

torch.autograd.functional.hessian(func, weight)
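
For concreteness, here is a minimal sketch of the accumulation I have in mind (assuming per-batch Hessians with `sum` reduction simply add, since the total sum-reduced loss is the sum of the per-batch losses; the synthetic dataset and DataLoader are just for illustration):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for a dataset too large to evaluate in one pass.
X = torch.randn(256, 3)
y = torch.zeros(256, dtype=torch.long)
loader = DataLoader(TensorDataset(X, y), batch_size=32)

layer = nn.Linear(3, 4)
criterion = nn.CrossEntropyLoss(reduction="sum")  # sum, so per-batch losses add up
weight = layer.weight

n_params = weight.numel()
hess_total = torch.zeros(n_params, n_params)

for xb, yb in loader:
    def func(w):
        # Same attribute-swapping hack as above, applied batch by batch.
        del layer.weight
        layer.weight = w
        return criterion(layer(xb), yb)

    h = torch.autograd.functional.hessian(func, weight)
    hess_total += h.reshape(n_params, n_params)

# Rescale by 1/N if the target loss is the mean over all samples.
hess_mean = hess_total / len(loader.dataset)
diag = torch.diagonal(hess_mean)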


Hi there, I see that you are using my silly hack solution haha:)
Please take a look at NN Module functional API · Issue #49171 · pytorch/pytorch · GitHub, which should provide an elegant solution after the 1.11 release.
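
The rough shape it gives you is a stateless call where the parameters are passed explicitly, so no attribute swapping is needed. A minimal sketch, assuming an API along the lines of torch.func.functional_call (the name used in recent releases; the exact import path may differ in your version):

import torch
import torch.nn as nn
from torch.func import functional_call  # import path may differ across versions

_input = torch.randn(32, 3)
target = torch.zeros(32, dtype=torch.long)
layer = nn.Linear(3, 4)
criterion = nn.CrossEntropyLoss()

def func(weight):
    # Run the module with `weight` substituted in, without mutating layer.
    return criterion(functional_call(layer, {"weight": weight}, (_input,)), target)

torch.autograd.functional.hessian(func, layer.weight.detach().clone())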

Neat, thanks @Ren_Pang. Don’t suppose you have a sense of when 1.11 might drop? Can’t imagine it’s too soon, given we’re on 1.9.1.

Thanks again!

They just finished the rc2 build for 1.10, which means it should be released within days.
I assume 1.11 will come more than 4 months after the 1.10 release.

This feature was implemented only a week or two ago, so it missed the 1.10 branch cut. And since it’s not a feature in urgent demand, it wasn’t included in the recent cherry-picks.