Custom autograd Function does not set requires_grad

I want to write a custom autograd Function to save GPU memory by recomputing results during the backward pass. The overall architecture is similar to invertible ResNets.

A simplified example of what I want to write is:

import torch


class MemorySaving(torch.autograd.Function):
    @staticmethod
    def forward(ctx, keep, pred, eval_function, combine_function, add_input, *weights):
        ctx.eval_function = eval_function
        ctx.combine_function = combine_function
        ctx.weights = weights
        with torch.no_grad():
            temp, add_out = eval_function(keep, add_input)
            res_pred, factor = combine_function.forward(temp, pred)
            # free pred's storage; it is reconstructed in backward via combine_function.reverse
            pred.set_()
            del pred
        ctx.save_for_backward(keep.data, res_pred, add_input)
        return (res_pred, add_out, factor)

    @staticmethod
    def backward(ctx, res_pred_grad, add_out_grad, factor_grad):
        eval_function = ctx.eval_function
        combine_function = ctx.combine_function
        weights = ctx.weights
        keep, res_pred, add_input = ctx.saved_tensors
        with torch.enable_grad():
            keep.requires_grad = True
            # recompute the forward pass and reconstruct pred from the saved outputs
            temp, add_out = eval_function(keep, add_input)
            pred, _ = combine_function.reverse(temp, res_pred)
            pred = pred.detach()
            pred.requires_grad = True
            resulting_pred, factor = combine_function.forward(temp, pred)
            grad_pred = torch.autograd.grad(resulting_pred, (keep, add_input, pred) + weights, res_pred_grad)
            grad_factor = torch.autograd.grad(factor, (keep, add_input) + weights, factor_grad)
            grad_add_out = torch.autograd.grad(add_out, (keep, add_input) + weights, add_out_grad)

            # omitting the code that combines the grads

        return (keep_grad, pred_grad, None, None, add_input_grad) + weights_grads

I am omitting some of the backward code because that function is (currently) not the problem.

My problem is that I have some code like this:

if memory_saving:
    res_pred, add_out, factor = MemorySaving.apply(keep, pred, eval_function, combine_function, add_input, *weights)
else:
    temp, add_out = eval_function(keep, add_input)
    res_pred, factor = combine_function.forward(temp, pred)

With memory_saving set to True, res_pred, add_out, and factor all have requires_grad set to False and no grad_fn. What am I doing wrong? I tried adding res_pred, add_out, factor = res_pred.clone(), add_out.clone(), factor.clone() at the end of the Function's forward, but it does not change anything.

Hi,

Is there a reason why you’re not using the checkpoint tool? It looks like exactly what you’re trying to reimplement.
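
For reference, it would look roughly like this (just a sketch, with placeholder tensors and a dummy eval_function standing in for your real ones):

import torch
from torch.utils.checkpoint import checkpoint

def eval_function(keep, add_input):  # stand-in for your real eval_function
    return keep * 2, keep + add_input

keep = torch.randn(4, requires_grad=True)
add_input = torch.randn(4)

# checkpoint() saves only the inputs and re-runs eval_function during the
# backward pass, so its intermediate activations are not kept in memory.
temp, add_out = checkpoint(eval_function, keep, add_input)
(temp.sum() + add_out.sum()).backward()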

For your code: the forward of a custom Function already runs with gradient tracking disabled (as if inside torch.no_grad()), so there is no need to add another one.
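
You can verify that with a tiny illustrative Function:

import torch

class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # grad mode is already disabled in here, so an explicit torch.no_grad() is redundant
        print(torch.is_grad_enabled())  # prints False
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

out = Identity.apply(torch.randn(2, requires_grad=True))
print(out.requires_grad)  # True: a Tensor input requires grad, so the output gets a grad_fn
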
The fact that you return a tuple might be a problem. Can you just return your three tensors, like return res_pred, add_out, factor?

If I combine them correctly, I get constant memory consumption independent of depth, which is different from checkpointing (I only need the resulting tensors to compute the input gradients, assuming add_out is None).

The fact that you return a tuple might be a problem. Can you just return your three tensors, like return res_pred, add_out, factor?

Using return res_pred, add_out, factor does not make any difference.

Ok, maybe I'm missing something. It's just that your implementation is very close to the checkpoint one, in the sense that you save the input and redo a forward/backward during the backward pass.

What are the types of the inputs to your function? Do you have at least one Tensor that requires gradients there?

This was the problem: I accidentally called it with self.parameters() and not *[p for p in self.parameters()].
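
In other words (a toy illustration; the backward is only a stub, and this assumes, as in my case, that only the weights require grad):

import torch

class Stash(torch.autograd.Function):
    # Stand-in for MemorySaving: the weights are only stored on ctx.
    @staticmethod
    def forward(ctx, x, *weights):
        ctx.weights = weights
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return (grad_out,) + (None,) * len(ctx.weights)

layer = torch.nn.Linear(3, 3)
x = torch.randn(3)  # does not require grad

wrong = Stash.apply(x, layer.parameters())                 # passes one generator object, no Tensors
right = Stash.apply(x, *[p for p in layer.parameters()])   # passes the parameter Tensors themselves

print(wrong.requires_grad)  # False, autograd never saw a Tensor that requires grad
print(right.requires_grad)  # True, the output gets a grad_fn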
