Why does backward() require 3x the memory of forward()?

I’m trying to understand the behavior of backward() in the following example:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def status( label ):
    # report ( currently allocated, peak allocated ) GPU memory in MiB
    print( f"{label}: ( {torch.cuda.memory_allocated() // 1024 ** 2:,} MiB, {torch.cuda.max_memory_allocated() // 1024 ** 2:,} MiB )" )


x = torch.randn( 1, 1, 1024, 1024 ).cuda()

status( "x" )

ncs    = [ 1, 256, 128, 64, 32, 16, 8, 4, 2, 1 ]
layers = [ nn.Conv2d( ncs[ i - 1 ], ncs[ i ], 3, 1, 1 ).cuda() for i in range( 1, len( ncs ) ) ]

status( "x + layers" )

def doit( x, layers, start, end, dummy = None ):
    # apply layers[ start : end ] to x; dummy is unused here and only passed in the checkpoint variant below
    for i in range( start, end ):
        x = layers[ i ]( x )
    return x

# res  = x
# segs = [ 0, 2, 5, len( layers ) ]
#
# for i in range( 1, len( segs ) ):
#     res = checkpoint( doit, res, layers, segs[ i - 1 ], segs[ i ], layers[ 0 ].weight )
#
# res = res.mean()

res = doit( x, layers, 0, len( layers ) ).mean()

status( "x + layer + res" )

del x

res.backward()

status( "x + layer + res + backward" )

The result I get:

x: ( 4 MiB, 4 MiB )
x + layers: ( 5 MiB, 5 MiB )
x + layer + res: ( 2,049 MiB, 2,049 MiB )
x + layer + res + backward: ( 3 MiB, 6,027 MiB )

I have difficulty understanding why backward() triples the memory load. I’d expect the gradient of a layer to be as big as its input, so it shouldn’t add more than 2 GiB. But even that is an overestimate, since not all gradients have to be stored, only the most recent one while moving backward. So the additional memory should be roughly the size of two layer activations. And even that is an overestimate, since activations that have already been processed can be freed. All in all, I would expect ( for this example ) a backward memory load only slightly higher than the forward load. Can someone explain this unexpectedly high memory need?
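
For reference, the ~2 GiB forward number itself matches what autograd keeps for backward: each conv saves its input to compute its weight gradient. A rough estimate, assuming float32 and reusing ncs from the script above:

# the saved inputs are x plus the outputs of all but the last conv, i.e. ncs[ :-1 ] channels in total
bytes_per_elem = 4                                                    # float32
saved = sum( nc * 1024 * 1024 * bytes_per_elem for nc in ncs[ :-1 ] )
print( f"saved activations: {saved // 1024 ** 2:,} MiB" )             # 2,044 MiB; x ( 4 MiB ) and the weights ( ~1 MiB ) account for the rest of the 2,049 MiB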

Unfortunately, just inspecting the peak memory usage via torch.cuda.max_memory_allocated() isn’t a reliable way to deduce the memory used only by tensors. It looks like in this case a backward cuDNN kernel allocates a temporary workspace, which inflates what torch.cuda.max_memory_allocated() reports. In general, cuDNN will opportunistically allocate temporary workspaces in order to run faster kernels (ones that require these workspaces) whenever there is enough memory available on the device.
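
One way to separate that transient part from live tensors in your example is to reset the peak counter right before backward(), so the reported peak covers the backward pass only. A small sketch reusing doit, layers, x, and status from your script:

res = doit( x, layers, 0, len( layers ) ).mean()

del x
torch.cuda.reset_peak_memory_stats()   # forget the forward-pass peak; track backward only

res.backward()

status( "backward only" )   # memory_allocated falls back to a few MiB, while the peak now covers
                            # just the backward pass ( starting from the ~2 GiB of saved activations ),
                            # so the excess is workspace and other temporary buffers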

It’s a bit crude, but you can force less memory to be used by calling torch.cuda.set_per_process_memory_fraction(value) in your script. On a 32 GiB GPU, I set the value to 0.08 and saw x + layer + res + backward: ( 3 MiB, 2,569 MiB ).
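
For completeness, the cap just needs to be set near the top of the script, before any allocations happen ( 0.08 is what I used on the 32 GiB GPU; pick a fraction that suits your device ):

import torch

# cap this process at ~8% of the device memory; allocations above the cap fail,
# so cuDNN's opportunistic workspace requests should fall back to algorithms
# that need little or no extra workspace
torch.cuda.set_per_process_memory_fraction( 0.08, device = 0 )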