The forward pass will create the computation graph, which is used in the backward pass to calculate the gradients.
`loss` is attached to this graph, and the graph won't be freed until `loss` is deleted or goes out of scope.
Of course the computation graph also holds references to the intermediate tensors, which are needed to compute the gradients and thus use memory as well.
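To see that the graph really stashes intermediate tensors, you can register saved-tensor hooks via `torch.autograd.graph.saved_tensors_hooks` (available in recent PyTorch versions). A minimal sketch, using a toy computation in place of a real model:

```python
import torch

saved = []  # shapes of every tensor the autograd graph saves for backward

def pack(t):
    saved.append(tuple(t.shape))  # record what gets stashed in the graph
    return t

def unpack(t):
    return t

x = torch.randn(8, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    inter = x.exp()     # exp saves its output for the backward pass
    loss = inter.sum()

print(len(saved) >= 1)  # the graph attached to `loss` holds intermediate tensors
```

These saved tensors stay alive as long as `loss` (and thus the graph) does.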
Given this simple loop:
```python
while True:
    loss = model(output)
```
These steps will happen:
- `model` and `output` are created and thus will use memory
- `loss = model(output)` represents the first forward pass. The computation graph as well as the intermediate tensors will be created. Memory usage increases to (model + output + loss0 + intermediates0)
- the iteration is done and the next one will start. Current memory usage is still (model + output + loss0 + intermediates0)
- the next iteration will start and another forward call will be kicked off. Memory usage during the forward pass and before the assignment of `loss` is (model + output + intermediates1 + loss0 + intermediates0). Note that two graphs are alive at this point.
- the new loss will be assigned to the variable `loss`. The old loss (loss0) will be freed, and with it intermediates0 and the corresponding graph
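A common way to avoid this two-graph peak is to free the old graph before the next forward pass, either by calling `backward()` (which frees the graph's buffers by default) or by explicitly deleting `loss` at the end of the iteration. A minimal sketch with a stand-in linear model (the names `model` and `output` follow the loop above; `torch.nn.Linear` and the shapes are just illustrations):

```python
import torch

model = torch.nn.Linear(10, 1)
output = torch.randn(64, 10)  # stands in for the model input, as in the loop above

for step in range(3):
    model.zero_grad()
    loss = model(output).sum()  # builds the graph + intermediates for this iteration
    loss.backward()             # by default, backward() frees the graph's buffers
    del loss                    # drop the last reference before the next forward,
                                # so two graphs never coexist at peak memory
```

With `del loss` (or an immediate `backward()`), the peak usage during the next forward pass no longer includes the previous iteration's graph and intermediates.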