Does calling loss.backward() reduce memory usage?

The forward pass will create the computation graph, which is used in the backward pass to calculate the gradients.
loss is attached to this graph, and the graph won’t be freed until loss is deleted or goes out of scope.
The computation graph also holds references to the intermediate tensors that are needed to compute the gradients, so these tensors use memory as well.
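
You can make this visible with a small sketch, assuming a CUDA device so that torch.cuda.memory_allocated() can be used; the names net and x are hypothetical stand-ins for a model and its input:

import torch

# hypothetical model and input, just to have some saved activations to track
net = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).cuda()
x = torch.randn(64, 1024, device='cuda')
print(torch.cuda.memory_allocated())  # parameters + input

loss = net(x).mean()                  # forward pass builds the graph
print(torch.cuda.memory_allocated())  # + loss + saved intermediate activations

del loss                              # last reference to the graph is dropped
print(torch.cuda.memory_allocated())  # the intermediates are freed again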

Given this simple loop:

while True:
    loss = model(output)

These steps will happen:

  • model and output are created and thus use memory.
  • loss = model(output) runs the first forward pass. The computation graph as well as the intermediate tensors are created, so memory usage increases to (model + output + loss0 + intermediates0).
  • the iteration finishes and the next one starts. Memory usage is still (model + output + loss0 + intermediates0).
  • during the next iteration another forward call is kicked off. Memory usage during this forward pass, before the result is assigned to loss, is (model + output + intermediates1 + loss0 + intermediates0). Note that two graphs are alive at this point.
  • the new loss is assigned to the variable loss. The old loss (loss0) is freed, and with it intermediates0 and the corresponding graph (see the sketch after this list).
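
The peak described in the fourth bullet can be observed directly. Below is a rough sketch, again assuming a CUDA device; the layer sizes and the names net and x are made up for illustration:

import torch

net = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 4096),
).cuda()
x = torch.randn(512, 4096, device='cuda')

for step in range(3):
    torch.cuda.reset_peak_memory_stats()
    loss = net(x).mean()  # the previous loss and its graph are still referenced here
    torch.cuda.synchronize()
    # from the second iteration on, the peak includes two sets of intermediates
    print(step, torch.cuda.max_memory_allocated())

Once the assignment to loss happens, the reference to the old graph is dropped, so this peak only lasts for the duration of the forward pass.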