Detecting when create_graph == True

In order to save memory, I would like to save Tensors for backward only when backward might actually be called. For first derivatives I can do this by checking whether the input has requires_grad=True. For second derivatives, the best I could apparently do is check whether create_graph == True. Is there a way to check, inside the forward method of an autograd.Function, whether create_graph == True?
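
For context: as far as I can tell, forward runs with grad mode disabled, so probing torch.is_grad_enabled() there reports False regardless of create_graph and cannot be used for this check. A small sketch of what I mean (Probe is just a throwaway name):

```python
import torch

seen = {}

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Record the grad mode visible while forward is executing.
        seen["grad_enabled_in_forward"] = torch.is_grad_enabled()
        return x.clone()

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y

x = torch.ones(1, requires_grad=True)
y = Probe.apply(x)
print(seen["grad_enabled_in_forward"])  # reports False in my runs
```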

class SquareForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        if x.requires_grad:  # Only save when backward might be called
            ctx.save_for_backward(x)
        return x**2

    @staticmethod
    def backward(ctx, grad_y):
        x, = ctx.saved_tensors
        return SquareBackward.apply(grad_y, x)

class SquareBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_y, x):
        # >>> Want to only save when double backward might be called <<<
        ctx.save_for_backward(grad_y, x)
        return 2 * grad_y * x

    @staticmethod
    def backward(ctx, grad_grad_x):
        grad_y, x = ctx.saved_tensors
        return grad_grad_x * 2 * x, grad_grad_x * 2 * grad_y

I would like to be able to write something like if grad_y.requires_grad or if ctx.create_graph to decide whether to save Tensors in SquareBackward's forward method. This matters especially when the Tensors I want to save are a large number of intermediate results that I would not otherwise allocate memory to store simultaneously. The only workaround at the moment seems to be to never store the intermediate results in SquareBackward's forward method and instead recompute them in SquareBackward's backward method, once we are sure that double backward is actually being performed, but that recomputation would be wasteful.
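
For reference, here is a self-contained, runnable version of the example above (with the missing @staticmethod decorators and import added), plus a quick sanity check that double backward goes through SquareBackward; the concrete values are only illustrative:

```python
import torch

class SquareBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_y, x):
        # Saved unconditionally for now -- this is exactly the cost I would
        # like to avoid when no double backward will follow.
        ctx.save_for_backward(grad_y, x)
        return 2 * grad_y * x

    @staticmethod
    def backward(ctx, grad_grad_x):
        grad_y, x = ctx.saved_tensors
        # d(grad_x)/d(grad_y) = 2x and d(grad_x)/dx = 2 * grad_y
        return grad_grad_x * 2 * x, grad_grad_x * 2 * grad_y

class SquareForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        if x.requires_grad:  # only save when backward might be called
            ctx.save_for_backward(x)
        return x**2

    @staticmethod
    def backward(ctx, grad_y):
        x, = ctx.saved_tensors
        return SquareBackward.apply(grad_y, x)

x = torch.tensor(3.0, requires_grad=True)
y = SquareForward.apply(x)
(g,) = torch.autograd.grad(y, x, create_graph=True)  # dy/dx = 2x
(gg,) = torch.autograd.grad(g, x)                    # d2y/dx2 = 2
print(g.item(), gg.item())  # 6.0 2.0
```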