Here is working code that computes a Hessian-vector product using double backward.
However, when the checkpoint feature is enabled (set use_checkpoint=True), the first backward call raises an error. This is somewhat expected, since the checkpoint function does not construct a graph for the gradient computation. Sadly, there is a more fundamental issue: even if one adds create_graph=True at https://github.com/pytorch/pytorch/blob/5d3a347685d826867973fd5f36d6f4c99f6b544b/torch/utils/checkpoint.py#L95, the returned result is still wrong (all zeros). This is because the checkpoint function recomputes the forward pass on its own detached copies of the inputs, so the recomputed graph is never connected to the original variables.
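For illustration, here is a rough, simplified sketch of what I mean (not the actual checkpoint implementation; x_local stands in for the detached copy it creates internally). Because the recomputation runs on a detached copy of the input, differentiating the recomputed gradient with respect to the original variable finds no path:

import torch

x = torch.randn(4, requires_grad=True)

# checkpoint recomputes the forward on a detached copy of the input:
x_local = x.detach().requires_grad_()
y = (x_local**2 / 2).sum()
g_local, = torch.autograd.grad(y, x_local, create_graph=True)

# g_local is connected to x_local, not to x, so the second
# differentiation finds nothing (allow_unused=True returns None
# instead of raising an error):
h, = torch.autograd.grad(g_local.sum(), x, allow_unused=True)
print(h)  # None: the recomputed graph never reaches the original x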
Is there a way to get correct results while still using the checkpoint function? This would be a highly desirable feature when the forward function whose Hessian we compute is memory-intensive. Thank you!
import torch
from torch.utils.checkpoint import checkpoint


def square(*args):
    x, = args
    return x**2 / 2


def forward(x, use_checkpoint):
    args = (x,)
    if use_checkpoint:
        x = checkpoint(square, *args)
    else:
        x = square(*args)
    return x.sum()


def hessp(x, p, use_checkpoint):
    # First backward with create_graph=True so that x.grad itself
    # carries a graph and can be differentiated again.
    loss1 = forward(x, use_checkpoint)
    loss1.backward(create_graph=True)
    # loss2 = grad(loss1) . p; its gradient w.r.t. x is H @ p.
    loss2 = x.grad.view(-1) @ p
    # Reset the accumulated gradient so the second backward
    # leaves exactly H @ p in x.grad.
    x.grad.data.zero_()
    loss2.backward()
    return x.grad


if __name__ == '__main__':
    torch.manual_seed(42)
    p = torch.randn(4)
    x = torch.randn(4).requires_grad_()
    hp = hessp(x, p, use_checkpoint=False)
    # square is x**2/2, so the Hessian is the identity and
    # hp should equal p.
    print(p)
    print(hp)
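For reference, the same Hessian-vector product can be written with torch.autograd.grad instead of the in-place .grad bookkeeping (hessp_functional is a hypothetical name; it reuses forward from the script above). This is only a sketch of the equivalent computation, and it runs into the same checkpoint limitation:

def hessp_functional(x, p, use_checkpoint):
    loss = forward(x, use_checkpoint)
    # First derivative, kept differentiable so it can be
    # differentiated a second time.
    g, = torch.autograd.grad(loss, x, create_graph=True)
    # d(g . p)/dx = H @ p
    hp, = torch.autograd.grad(g.view(-1) @ p, x)
    return hp

This variant avoids zeroing x.grad by hand, since torch.autograd.grad returns the gradients directly instead of accumulating them.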