I’m trying to understand the behavior of backward() in the following example:
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def status( label ):
    print( f"{label}: ( {torch.cuda.memory_allocated() // 1024 ** 2:,} MiB, {torch.cuda.max_memory_allocated() // 1024 ** 2:,} MiB )" )

x = torch.randn( 1, 1, 1024, 1024 ).cuda()
status( "x" )

ncs = [ 1, 256, 128, 64, 32, 16, 8, 4, 2, 1 ]
layers = [ nn.Conv2d( ncs[ i - 1 ], ncs[ i ], 3, 1, 1 ).cuda() for i in range( 1, len( ncs ) ) ]
status( "x + layers" )

def doit( x, layers, start, end, dummy = None ):
    for i in range( start, end ):
        x = layers[ i ]( x )
    return x

# res = x
# segs = [ 0, 2, 5, len( layers ) ]
#
# for i in range( 1, len( segs ) ):
#     res = checkpoint( doit, res, layers, segs[ i - 1 ], segs[ i ], layers[ 0 ].weight )
#
# res = res.mean()

res = doit( x, layers, 0, len( layers ) ).mean()
status( "x + layer + res" )

del x
res.backward()
status( "x + layer + res + backward" )
```
The result I get:
```
x: ( 4 MiB, 4 MiB )
x + layers: ( 5 MiB, 5 MiB )
x + layer + res: ( 2,049 MiB, 2,049 MiB )
x + layer + res + backward: ( 3 MiB, 6,027 MiB )
```
I have difficulty understanding why backward() triples the peak memory load. I'd expect the gradient of a layer's output to be as big as that output, so it shouldn't add more than 2 GiB. But even that is an overestimate, since not all gradients have to be stored at once, only the most recent ones while moving backward. So the additional memory should be roughly the size of two layer activations. And even that is an overestimate, since already-processed activations can be freed. All in all, I would expect ( for this example ) a backward memory load only slightly higher than the forward load. Can someone explain this unexpectedly high memory need?
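To make the arithmetic behind my expectation concrete, here is a rough tally of the forward activation sizes for the network above (my own back-of-the-envelope estimate, assuming float32 and that every intermediate activation is kept alive for backward; the spatial size stays 1024x1024 through all the 3x3/stride-1/pad-1 convolutions):

```python
# Per-activation size in MiB: channels * H * W * 4 bytes, for H = W = 1024.
# ncs mirrors the channel counts from the snippet above; ncs[ 0 ] is the input.
ncs = [ 1, 256, 128, 64, 32, 16, 8, 4, 2, 1 ]
mib = [ c * 1024 * 1024 * 4 / 1024 ** 2 for c in ncs ]

print( mib )         # [4.0, 1024.0, 512.0, 256.0, 128.0, 64.0, 32.0, 16.0, 8.0, 4.0]
print( sum( mib ) )  # 2048.0 MiB held after the forward pass
```

That 2,048 MiB matches the observed "x + layer + res" reading of 2,049 MiB, which is why I expected backward to add at most about the same again, not triple it.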