Is it possible to calculate the Hessian of a network while using gradient checkpointing?

Hi All,

I just have a general question about the use of gradient checkpointing! I've recently been looking into this method, and it seems it would be quite useful for my current research, as I'm running out of CUDA memory.

After reading the docs, it looks like gradient checkpointing doesn't support torch.autograd.grad(), only torch.autograd.backward(). Within my model, I use both torch.autograd.grad and torch.autograd.backward, as my loss function depends on the Laplacian (the trace of the Hessian) of the network with respect to the inputs, plus some other terms. In actual code, this looks something like:

output = model(input)  # R^N -> R^1 function (nn.Module)
loss_values = calc_laplacian(model, input) + input.pow(2)  # M samples in the batch
loss = loss_values.mean()  # mean over the batch to get the loss for this batch

optim.zero_grad()  # clear previously accumulated gradients
loss.backward()    # compute gradients of the loss w.r.t. the parameters
optim.step()       # update parameters

where calc_laplacian is my own custom function that computes the Laplacian (via torch.autograd.grad) of an nn.Module with respect to input, the input Tensor to the model.
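
For reference, here is a simplified sketch of roughly what calc_laplacian does (not my exact implementation; it assumes each sample's output depends only on its own input row, so summing the outputs over the batch gives per-sample gradients):

import torch

def calc_laplacian(model, input):
    # Sketch: trace of the Hessian of the scalar output w.r.t. each input sample.
    input = input.requires_grad_(True)
    output = model(input)  # shape [M, 1]
    # First derivatives; create_graph=True so we can differentiate again.
    grads = torch.autograd.grad(output.sum(), input, create_graph=True)[0]  # [M, N]
    laplacian = torch.zeros(input.shape[0], device=input.device)
    for i in range(input.shape[1]):
        # Second derivative of each output w.r.t. the i-th input coordinate;
        # create_graph=True again so loss.backward() can reach the parameters.
        grad2 = torch.autograd.grad(grads[:, i].sum(), input, create_graph=True)[0]
        laplacian = laplacian + grad2[:, i]  # accumulate the i-th diagonal entry
    return laplacian  # shape [M]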

Would it be possible to implement gradient checkpointing for such a model, and if so, are there any examples available online?
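
For concreteness, this is roughly the kind of usage I have in mind (just a sketch of wrapping the forward pass with torch.utils.checkpoint.checkpoint; I haven't tested it together with the double torch.autograd.grad calls above):

import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical wrapper: recompute the model's activations during the backward
# pass instead of storing them, to save memory.
def checkpointed_forward(model, input):
    return checkpoint(model, input)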

Any help is appreciated! Thank you! :slight_smile: