Checkpointing does not work well with double backward

Here is working code that computes a Hessian-vector product using double backward.

However, when using the checkpoint feature (set use_checkpoint=True), one gets an error. This is somewhat expected, since the checkpoint function does not construct a graph for the gradient computation. Sadly, there is a more fundamental issue: even if one adds create_graph=True to the inner gradient call, the returned result is still wrong (all zeros). This is because the checkpoint function recomputes the forward pass on its own locally detached variables.

Is there a way to get the correct results while still using the checkpoint function?
This is a highly desirable feature when the forward function whose Hessian we compute is memory-expensive. Thank you!

import torch 
from torch.utils.checkpoint import checkpoint

def square(*args):
    x, = args
    return x**2/2

def forward(x, use_checkpoint):
    args = (x,)
    if use_checkpoint:
        x = checkpoint(square, *args)
    else:
        x = square(*args)
    return x.sum()

def hessp(x, p, use_checkpoint):
    loss1 = forward(x, use_checkpoint)
    # First backward: keep the graph so we can differentiate the gradient itself.
    grad, = torch.autograd.grad(loss1, x, create_graph=True)
    loss2 = grad.view(-1) @ p
    # Second backward: d(grad . p)/dx is the Hessian-vector product H p.
    hp, = torch.autograd.grad(loss2, x)
    return hp

if __name__=='__main__':
    p = torch.randn(4)
    x = torch.randn(4).requires_grad_()
    hp = hessp(x, p, use_checkpoint=False)

    # For f(x) = sum(x**2 / 2) the Hessian is the identity, so hp should equal p.
    print(p)
    print(hp)
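One possible way around the limitation, while keeping checkpointing, is to avoid double backward entirely and approximate the Hessian-vector product with a finite difference of two first-order gradients, Hp ≈ (∇f(x + εp) − ∇f(x)) / ε. Each gradient is an ordinary single backward pass, which checkpointing supports. This is only a sketch, not an exact HVP: the names `hessp_fd` and `eps` are my own, and the `use_reentrant=False` argument to `checkpoint` is an assumption that requires a reasonably recent PyTorch (it exists from roughly 1.11 onward).

```python
import torch
from torch.utils.checkpoint import checkpoint

def square(*args):
    x, = args
    return x**2 / 2

def forward(x, use_checkpoint):
    if use_checkpoint:
        # Assumption: use_reentrant=False needs PyTorch >= ~1.11; drop it on older versions.
        y = checkpoint(square, x, use_reentrant=False)
    else:
        y = square(x)
    return y.sum()

def hessp_fd(x, p, use_checkpoint, eps=1e-3):
    # Hp ~= (grad f(x + eps*p) - grad f(x)) / eps.
    # Two independent first-order backward passes, so the checkpointed
    # region never has to support double backward.
    x1 = x.detach().clone().requires_grad_()
    forward(x1, use_checkpoint).backward()
    x2 = (x.detach() + eps * p).requires_grad_()
    forward(x2, use_checkpoint).backward()
    return (x2.grad - x1.grad) / eps
```

The price is a truncation error of order ε and one extra forward/backward pass, but the memory savings from checkpointing are preserved; for the quadratic example above the approximation is exact up to floating-point rounding, since the gradient is linear in x.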