Does calling loss.backward() reduce memory usage?

It’s because of the scoping rules in Python: a name bound inside a loop keeps its object alive until the name is rebound or deleted. When you do this:

while True:
    loss = model(input)

it will use roughly 2x the memory needed to compute the model, because the reference to the loss from the previous iteration won’t be overwritten (and thus the graph, with all the buffers it holds, won’t be freed) until the new forward pass has finished and its result has been assigned to loss. So you’ll effectively end up holding on to two graphs. This is why you should use volatile=True inputs when only doing inference (as of PyTorch 0.4, volatile is deprecated and has no effect; wrap inference code in with torch.no_grad(): instead). Once you add .backward(), the buffers will be freed in the process of computing the derivatives, limiting the memory usage to that of a single graph (the old graph will still be kept around, but it won’t be holding on to any of its buffers).
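A minimal pure-Python sketch of this effect, assuming a hypothetical FakeGraph class that stands in for an autograd graph and a hypothetical model() function. No PyTorch is involved: the sketch only models CPython's reference counting, and its backward() merely mimics autograd releasing saved buffers.

```python
# Hypothetical stand-in for an autograd graph: it only tracks how many
# "buffer units" are held, so the effect of rebinding `loss` is visible
# without PyTorch. Relies on CPython freeing an object as soon as its
# last reference disappears.
class FakeGraph:
    live_units = 0  # buffer units currently held by all graphs
    peak_units = 0  # highest value live_units has reached

    def __init__(self, units):
        self.units = units
        FakeGraph.live_units += units
        FakeGraph.peak_units = max(FakeGraph.peak_units, FakeGraph.live_units)

    def backward(self):
        # Mimics autograd releasing the saved buffers while computing gradients.
        FakeGraph.live_units -= self.units
        self.units = 0

    def __del__(self):
        # Anything still held is released when the graph object is freed.
        FakeGraph.live_units -= self.units


def model(x):
    # Hypothetical model: every forward pass builds a graph holding 100 units.
    return FakeGraph(100)


def peak_buffer_units(iterations, call_backward):
    FakeGraph.peak_units = 0
    for _ in range(iterations):
        # The previous graph is released only after model() has returned
        # and the new result is bound to `loss`.
        loss = model(None)
        if call_backward:
            loss.backward()
    return FakeGraph.peak_units


print(peak_buffer_units(3, call_backward=False))  # 200: two graphs' buffers at once
print(peak_buffer_units(3, call_backward=True))   # 100: the old graph holds nothing
```

Without backward() the peak is two graphs' worth of buffers; with it, the old graph is still alive while the next forward pass runs, but it no longer holds any buffers.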

This is also why del helps reduce memory usage. This loop will keep at most a single graph alive, because the loss is created and disposed of within a single iteration.

while True:
    loss = model(input)
    del loss  # This frees the graph (if nothing else references it)
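A similar pure-Python sketch for the del version, again assuming a hypothetical FakeGraph that only counts live instances and a hypothetical model(). It also illustrates a caveat: del removes a name, not an object, so the graph is freed only if no other reference (here a hypothetical history list) still points to it.

```python
import weakref


class FakeGraph:
    """Hypothetical stand-in for an autograd graph; counts live instances."""
    alive = 0
    peak = 0

    def __init__(self):
        FakeGraph.alive += 1
        FakeGraph.peak = max(FakeGraph.peak, FakeGraph.alive)

    def __del__(self):
        FakeGraph.alive -= 1


def model(x):
    # Hypothetical model: each forward pass builds a fresh graph.
    return FakeGraph()


def peak_graphs(iterations, use_del):
    FakeGraph.peak = 0
    for _ in range(iterations):
        loss = model(None)
        if use_del:
            del loss  # drop the last reference before the next forward pass
    return FakeGraph.peak


print(peak_graphs(3, use_del=False))  # 2
print(peak_graphs(3, use_del=True))   # 1

# del only removes a name; the graph survives if something else refers to it.
loss = model(None)
ref = weakref.ref(loss)
history = [loss]      # e.g. collecting losses for logging
del loss
print(ref() is None)  # False: `history` still holds the graph
history.clear()
print(ref() is None)  # True: last reference gone, graph freed
```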