I want to write a custom autograd function to save GPU memory by recomputing results during the backward pass. The overall architecture is similar to invertible ResNets.

A simplified example of what I want to write is:

```
class MemorySaving(torch.autograd.Function):
    @staticmethod
    def forward(ctx, keep, pred, eval_function, combine_function, add_input, *weights):
        # non-tensor arguments are stashed on ctx, not in save_for_backward
        ctx.eval_function = eval_function
        ctx.combine_function = combine_function
        ctx.weights = weights
        with torch.no_grad():  # build no graph here; everything is recomputed in backward
            temp, add_out = eval_function(keep, add_input)
            res_pred, factor = combine_function.forward(temp, pred)
        # free pred's storage; it is reconstructed from res_pred in backward
        pred.set_()
        del pred
        ctx.save_for_backward(keep, res_pred, add_input)
        return (res_pred, add_out, factor)

    @staticmethod
    def backward(ctx, res_pred_grad, add_out_grad, factor_grad):
        eval_function = ctx.eval_function
        combine_function = ctx.combine_function
        weights = ctx.weights
        keep, res_pred, add_input = ctx.saved_tensors
        with torch.enable_grad():  # rebuild the graph for the recomputation
            keep = keep.detach().requires_grad_(True)
            temp, add_out = eval_function(keep, add_input)
            # invert the combine step to recover the discarded pred
            pred, _ = combine_function.reverse(temp, res_pred)
            pred = pred.detach()
            pred.requires_grad = True
            resulting_pred, factor = combine_function.forward(temp, pred)
        grad_pred = torch.autograd.grad(resulting_pred, (keep, add_input, pred) + weights, res_pred_grad)
        grad_factor = torch.autograd.grad(factor, (keep, add_input) + weights, factor_grad)
        grad_add_out = torch.autograd.grad(add_out, (keep, add_input) + weights, add_out_grad)
        # omitting the code that combines the three gradient tuples
        return (keep_grad, pred_grad, None, None, add_input_grad) + weights_grads
```
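
For context, `combine_function` is an invertible coupling with a `forward`/`reverse` pair. My real implementation is more involved; an additive stand-in would look roughly like this:

```
class AdditiveCombine:
    # Stand-in coupling: forward produces res_pred from (temp, pred),
    # reverse recovers pred exactly from (temp, res_pred).
    def forward(self, temp, pred):
        res_pred = pred + temp   # invertible combine step
        factor = temp.sum()      # placeholder for the extra factor output
        return res_pred, factor

    def reverse(self, temp, res_pred):
        pred = res_pred - temp   # exact inverse of forward
        factor = temp.sum()
        return pred, factor
```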

I am omitting some of the code in the backward function because that part is (currently) not the problem.

My problem is that I have some code like this:

```
if memory_saving:
    res_pred, add_out, factor = MemorySaving.apply(keep, pred, eval_function, combine_function, add_input, *weights)
else:
    temp, add_out = eval_function(keep, add_input)
    res_pred, factor = combine_function.forward(temp, pred)
```
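
For reference, this is how I inspect the results (as I understand it, `apply` should attach a `grad_fn` to its outputs as long as grad mode is enabled at the call site and at least one tensor input requires grad):

```
res_pred, add_out, factor = MemorySaving.apply(
    keep, pred, eval_function, combine_function, add_input, *weights)
print(torch.is_grad_enabled())                 # check grad mode at the call site
print(any(t.requires_grad
          for t in (keep, pred, add_input) + weights))  # any tensor input requiring grad?
print(res_pred.requires_grad, res_pred.grad_fn)         # what I get: False, None
```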

and `res_pred, add_out, factor` all have `requires_grad` set to `False` and no `grad_fn` if `memory_saving` is `True`. What am I doing wrong? I tried adding a `res_pred, add_out, factor = res_pred.clone(), add_out.clone(), factor.clone()` to the `forward` of my `torch.autograd.Function`, but it does not change anything.
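
For completeness, here is a stripped-down, self-contained sketch of the same pattern (with trivial stand-ins for `eval_function` and `combine_function`, so this is not my real code). This reduced version behaves as I would expect, which makes the behaviour of the full version above even more confusing:

```
import torch

class MiniSaving(torch.autograd.Function):
    @staticmethod
    def forward(ctx, keep, pred, weight):
        with torch.no_grad():
            temp = keep * weight      # stand-in for eval_function
            res_pred = pred + temp    # stand-in for combine_function.forward
        ctx.save_for_backward(keep, res_pred, weight)
        return res_pred

    @staticmethod
    def backward(ctx, res_pred_grad):
        keep, res_pred, weight = ctx.saved_tensors
        with torch.enable_grad():
            keep = keep.detach().requires_grad_(True)
            weight = weight.detach().requires_grad_(True)
            temp = keep * weight
            # recover the discarded pred by inverting the combine step
            pred = (res_pred - temp).detach().requires_grad_(True)
            out = pred + temp
        keep_grad, pred_grad, weight_grad = torch.autograd.grad(
            out, (keep, pred, weight), res_pred_grad)
        return keep_grad, pred_grad, weight_grad

keep = torch.randn(4)
pred = torch.randn(4)
weight = torch.randn(4, requires_grad=True)
res_pred = MiniSaving.apply(keep, pred, weight)
print(res_pred.requires_grad, res_pred.grad_fn)  # True, <MiniSavingBackward...>
res_pred.sum().backward()
print(weight.grad)
```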