Re-forwarding in `backward` causes higher GPU memory usage than expected

In PyTorch, I use `save_for_backward` to save the input tensors of certain layers, and then run the forward of those layers in `torch.no_grad` mode. My goal is to recompute the intermediate data during the backward of a `torch.autograd.Function` without any memory leak, and to predict GPU memory usage during backward layer by layer. However, the GPU memory usage reported by `torch.cuda.memory_allocated` is much higher than I expected. What could be the cause? Alternatively, could someone point me to the “formal” way of measuring GPU memory usage layer by layer during the backward of a `torch.autograd.Function`?

I also tried APIs such as `register_full_backward_hook`, but they made things worse: the reported memory was even higher than with my current method. (According to a recent issue on PyTorch's GitHub, PyTorch seems to leak memory when `self` is referenced inside these hooks.)

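```python
import torch


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        # args = (input tensors..., parameters...); the first `length` are inputs
        ctx.run_function = run_function
        ctx.length = length
        # Keep only the segment inputs (and parameters) for the backward pass
        ctx.save_for_backward(*args)
        # Run the segment without building a graph, so no intermediates are kept
        with torch.no_grad():
            output_tensors = ctx.run_function(*args[:length])
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        input_tensors = ctx.saved_tensors[:ctx.length]
        input_params = list(ctx.saved_tensors[ctx.length:])
        # Detached tensors share storage with the saved inputs but start a new graph
        input_tensors = [x.detach().requires_grad_(True) for x in input_tensors]
        # Re-forward the segment with autograd enabled to rebuild its graph
        with torch.enable_grad():
            output_tensors = ctx.run_function(*input_tensors)

        input_grads = torch.autograd.grad(
            output_tensors,
            input_tensors + input_params,
            output_grads,
            allow_unused=True,
        )
        # Gradients w.r.t. the input tensors are discarded; only parameter grads are returned
        input_grads = (None,) * ctx.length + input_grads[ctx.length:]

        del output_tensors
        if ctx.run_function.is_enabled_logging():
            print('## VGG, segment re-forward, torch.cuda.memory_allocated(): {}MB'.format(
                torch.cuda.memory_allocated() / 2 ** 20))

        # No gradients for run_function and length
        return (None, None) + input_grads
```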
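For the “formal” per-layer trace during backward, one option that avoids `register_full_backward_hook` is to attach a plain forward hook to each leaf module and, inside it, register a tensor hook (`Tensor.register_hook`) on that module's output. The tensor hook fires when autograd has computed the gradient with respect to that output, i.e. just before that layer's own backward runs, so reading `torch.cuda.memory_allocated()` there yields one sample per layer in backward order. This is a minimal sketch, not an official PyTorch facility: the helpers `attach_backward_memory_trace`, `_allocated_bytes` and `run_demo` are hypothetical, it assumes each leaf module returns a single tensor, and it has not been checked against the custom `CheckpointFunction` above (inside a checkpointed segment the outputs are created under `torch.no_grad`, so hooks would only be registered during the re-forward). For transient peaks, `torch.cuda.max_memory_allocated()` combined with `torch.cuda.reset_peak_memory_stats()` may be more informative than the current value.

```python
import torch
import torch.nn as nn


def _allocated_bytes() -> int:
    # memory_allocated() counts live tensors only (not the caching allocator's reserve);
    # on a CPU-only machine this sketch simply records 0.
    return torch.cuda.memory_allocated() if torch.cuda.is_available() else 0


def attach_backward_memory_trace(model: nn.Module, log: list) -> list:
    """Hypothetical helper: append (layer_name, allocated_bytes) to `log` each time
    the gradient w.r.t. a leaf module's output is computed during backward."""
    handles = []

    def make_forward_hook(name: str):
        # Only the layer name and `log` are captured, not the module object itself.
        def forward_hook(module, inputs, output):
            # Assumption: each leaf module returns a single tensor.
            if isinstance(output, torch.Tensor) and output.requires_grad:
                def grad_hook(grad):
                    # Returning None (implicitly) leaves the gradient unchanged.
                    log.append((name, _allocated_bytes()))
                output.register_hook(grad_hook)
        return forward_hook

    for name, module in model.named_modules():
        if next(module.children(), None) is None:  # leaf modules only
            handles.append(module.register_forward_hook(make_forward_hook(name)))
    return handles


def run_demo() -> list:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)).to(device)
    log = []
    handles = attach_backward_memory_trace(model, log)
    model(torch.randn(2, 8, device=device)).sum().backward()
    for h in handles:
        h.remove()  # detach the forward hooks once tracing is done
    return log


if __name__ == "__main__":
    for name, nbytes in run_demo():
        print(f"layer {name}: {nbytes / 2**20:.2f} MiB allocated")
```

The entries arrive in reverse layer order (`"2"`, `"1"`, `"0"` for the toy model). Because each sample is taken at a gradient-ready point, the difference between consecutive entries roughly reflects the net memory change of one layer's backward, not its transient peak.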

It is hard to diagnose the problem without a runnable example that shows how your results differ from what you expect. Is the problem that the observed memory usage is higher than expected?

I could not tell how many layers, or which types of layers, your function checkpoints. Note, however, that if you checkpoint only a single layer, you should not expect any memory savings: the input you keep via `save_for_backward` is already all of the “intermediate” state that layer would have saved anyway, and the re-forward only adds computation on top.
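The point about single-layer checkpointing can be made concrete with a back-of-the-envelope model. The sketch below is a deliberately simplified assumption, not a measurement of PyTorch's allocator: every layer saves exactly one input activation of `act_bytes` for its backward, all activations are the same size, and checkpointing keeps only each segment's input (as `save_for_backward` does in the function above) but re-creates roughly one segment's worth of activations while that segment is being differentiated. The names `estimate` and `ActivationEstimate` are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ActivationEstimate:
    kept_after_forward: int   # bytes held alive between forward and backward
    recompute_transient: int  # extra bytes re-created while one segment is differentiated


def estimate(num_segments: int, layers_per_segment: int, act_bytes: int,
             checkpointed: bool) -> ActivationEstimate:
    """Simplified model (an assumption): each layer saves one input activation
    of `act_bytes` bytes for its backward pass."""
    if num_segments < 1 or layers_per_segment < 1 or act_bytes < 0:
        raise ValueError("counts must be >= 1 and act_bytes must be >= 0")
    if not checkpointed:
        # Plain autograd keeps the input of every layer alive until backward.
        return ActivationEstimate(num_segments * layers_per_segment * act_bytes, 0)
    # Checkpointing keeps only each segment's input...
    kept = num_segments * act_bytes
    # ...and re-creates about one segment's activations during its backward.
    return ActivationEstimate(kept, layers_per_segment * act_bytes)


MiB = 2 ** 20
# Four single-layer segments: checkpointing keeps as much as plain autograd.
print(estimate(4, 1, MiB, checkpointed=False).kept_after_forward // MiB)  # 4
print(estimate(4, 1, MiB, checkpointed=True).kept_after_forward // MiB)   # 4
# One four-layer segment: 1 MiB kept, about 4 MiB re-created during its backward.
print(estimate(1, 4, MiB, checkpointed=True))
```

With one layer per segment, the kept memory is identical to plain autograd and the re-forward only adds computation and a transient allocation; under this model, savings appear only when a segment spans several layers.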