Checkpointing does not work well with double backward

Here is working code that computes a Hessian-vector product using double backward.

However, when the checkpoint feature is used (set use_checkpoint=True), one gets an error. This is somewhat expected, since the checkpoint function does not construct a graph for the gradient. Sadly, there is a more fundamental issue: even if one adds create_graph=True to https://github.com/pytorch/pytorch/blob/5d3a347685d826867973fd5f36d6f4c99f6b544b/torch/utils/checkpoint.py#L95, the returned result is still wrong (all zeros). This is because the checkpoint function works on its own locally detached variables.
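
To illustrate that second point with a rough sketch (this is not the actual checkpoint implementation, and it is separate from the full example at the end of the post; x_local is just an illustrative name): recomputing the forward pass on a detached copy of the input, the way checkpoint does internally, severs the second-order graph back to x, so double backward degenerates to a zero/None result.

import torch

x = torch.randn(4, requires_grad=True)

# Checkpoint-style local copy: detached from the original graph.
x_local = x.detach().requires_grad_()
y = (x_local ** 2 / 2).sum()

# The first-order gradient is attached to x_local, not to x.
g, = torch.autograd.grad(y, x_local, create_graph=True)

# Differentiating g w.r.t. the original x finds no path in the graph:
# with allow_unused=True this returns None (effectively a zero result).
hv, = torch.autograd.grad(g.sum(), x, allow_unused=True)
print(hv)  # None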

Is there a way to get correct results while still using the checkpoint function?
This would be a highly desirable feature when the forward function for which we compute the Hessian is memory-intensive. Thank you!

import torch
from torch.utils.checkpoint import checkpoint

def square(*args):
    x, = args
    return x**2/2

def forward(x, use_checkpoint):
    args = x, 
    if use_checkpoint:
        x = checkpoint(square, *args)
    else:
        x = square(*args)
    return x.sum()

def hessp(x, p, use_checkpoint):
    # First backward pass with create_graph=True, so the gradient itself
    # stays differentiable.
    loss1 = forward(x, use_checkpoint)
    loss1.backward(create_graph=True)
    # Dot the gradient with p, then backpropagate once more to get H @ p.
    loss2 = x.grad.view(-1)@p
    x.grad.data.zero_()
    loss2.backward()
    return x.grad

if __name__=='__main__':
    torch.manual_seed(42)
    p = torch.randn(4)
    x = torch.randn(4).requires_grad_()
    hp = hessp(x, p, use_checkpoint=False)

    # Since f(x) = sum(x**2 / 2) has the identity as its Hessian, hp should equal p.
    print(p)
    print(hp)
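
For comparison, here is a sketch of the same Hessian-vector product formed with torch.autograd.grad instead of backward() plus x.grad (hessp_autograd is an illustrative name; it assumes the square/forward definitions above and does not change how checkpoint behaves):

def hessp_autograd(x, p, use_checkpoint):
    # Same Hessian-vector product as hessp, but via torch.autograd.grad,
    # so x.grad is never mutated in place.
    loss = forward(x, use_checkpoint)
    grad, = torch.autograd.grad(loss, x, create_graph=True)
    hv, = torch.autograd.grad(grad.view(-1) @ p, x)
    return hv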