Does calling loss.backward() reduce memory usage?

It seems that calling loss.backward() helps save memory, which is not very intuitive to me.

import sys
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import models

try:
    import gpustat
except ImportError:
    raise ImportError("pip install gpustat")


def show_memusage(device=0):
    gpu_stats = gpustat.GPUStatCollection.new_query()
    item = gpu_stats.jsonify()["gpus"][device]
    print("{}/{}".format(item["memory.used"], item["memory.total"]))


device = 0
show_memusage(device=device)

torch.cuda.set_device(device)
model = models.resnet101(pretrained=False)
model.cuda()
criterion = nn.CrossEntropyLoss()

volatile = False

show_memusage(device=device)
for ii in range(3):
    inputs = torch.randn(20, 3, 224, 224)
    labels = torch.LongTensor(range(20))
    inputs = inputs.cuda()
    labels = labels.cuda()
    inputs = Variable(inputs, volatile=volatile)
    labels = Variable(labels, volatile=volatile)

    print "before run model:",
    show_memusage(device=device)
    outputs = model(inputs)
    print "after run model:",
    show_memusage(device=device)

    if bool(int(sys.argv[1])):
        loss = criterion(outputs, labels)
        print "before backward:",
        show_memusage(device=device)
        loss.backward()
        print "after backward:",
        show_memusage(device=device)
    print

Running “python test_dsetloadermem.py 0” gives:

mem used / mem total
612/8110
1045/8110
before run model: 1055/8110
after run model: 3699/8110

before run model: 3711/8110
after run model: 6215/8110

before run model: 6215/8110
after run model: 6215/8110

Running “python test_dsetloadermem.py 1” gives:

mem used / mem total
618/8110
1051/8110
before run model: 1063/8110
after run model: 3699/8110
before backward: 3699/8110
after backward: 4131/8110

before run model: 4131/8110
after run model: 4131/8110
before backward: 4131/8110
after backward: 4131/8110

before run model: 4131/8110
after run model: 4131/8110
before backward: 4131/8110
after backward: 4131/8110

I’ve also noticed that

del outputs
del inputs
del labels

can help avoid unnecessary memory usage.

Could you give any insight into why loss.backward() does something similar?
Thanks!


It’s because of the scoping rules in Python. When you do this:

while True:
    loss = model(input)

it will always use 2x the memory that is needed to compute the model, because the reference to the loss from the previous iteration won’t be overwritten (and thus the graph with all the buffers it holds won’t be freed) until this iteration completes. So you’ll effectively end up holding on to two graphs. This is why you should use volatile=True inputs when only doing inference. Once you add .backward(), the buffers will be freed in the process of computing the derivatives, limiting the memory usage to that of a single graph (the old graph will still be kept around, but it won’t be holding on to any memory).
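
For the inference-only case, a minimal sketch, assuming the same PyTorch 0.x API as the script above (on 0.4 and later you would wrap the loop in torch.no_grad() instead of using volatile):

for _ in range(3):
    inputs = Variable(torch.randn(20, 3, 224, 224).cuda(), volatile=True)
    outputs = model(inputs)  # volatile propagates through the graph: no buffers
                             # are saved, so memory stays flat across iterations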

This is also why del helps reduce memory usage. The following loop will keep at most a single graph alive, because the loss is created and disposed of within a single iteration:

while True:
    loss = model(input)
    del loss # This frees the graph
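
And to tie this back to the original question, a rough sketch of the training-style loop, reusing the model, criterion, inputs and labels from the script in the question: calling backward() frees the intermediate buffers while the gradients are computed, so at most a single graph’s worth of buffers is alive at any time.

while True:
    loss = criterion(model(inputs), labels)
    loss.backward()  # buffers are freed while the gradients are computed; only
                     # the (now empty) graph survives until loss is reassigned
    # (a real training loop would also zero the gradients and step an optimizer)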

Thanks for the explanation! That helps a lot!

Hi @apaszke, I’m still a little uncertain about the part where you say:

the reference to the loss from the previous iteration won’t be overwritten … until this iteration completes.

Can you elaborate on that piece a little more?

The forward pass will create the computation graph, which is used in the backward pass to calculate the gradients.
loss is attached to this graph, and the graph won’t be freed until loss is deleted or goes out of scope.
Of course the computation graph holds references to intermediate tensors, which are also needed to compute the gradients, and will thus use memory.
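
A quick way to see this, reusing the model, criterion, inputs and labels from the script in the question and assuming a PyTorch version that provides torch.cuda.memory_allocated():

outputs = model(inputs)
loss = criterion(outputs, labels)
print(loss.grad_fn)                    # the autograd node that ties loss to the graph
print(torch.cuda.memory_allocated())   # large: the intermediate tensors are still held
del outputs, loss                      # drop the last references to the graph
print(torch.cuda.memory_allocated())   # smaller: the buffers have been freed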

Given this simple loop:

while True:
    loss = model(input)

These steps will happen:

  • model and input are created and thus will use memory
  • loss = model(input) represents the first forward pass. The computation graph as well as the intermediate tensors will be created. Memory usage increases to (model + input + loss0 + intermediates0)
  • the first iteration is done and the next one will start. Current memory usage is still (model + input + loss0 + intermediates0)
  • the next iteration starts and another forward call is kicked off. Memory usage during this forward pass, before the assignment to loss, is (model + input + intermediates1 + loss0 + intermediates0). Note that we end up holding two graphs at this point
  • the new loss will be assigned to the variable loss. Only now is the old loss (loss0) freed, and with it intermediates0 and the corresponding graph (see the sketch below)
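
The same steps can be made explicit with a small sketch, reusing model and input from the loop above and again assuming torch.cuda.memory_allocated() is available; the assignment is split in two only to expose the moment where both graphs are alive:

loss = model(input)                    # first forward pass builds graph 0
print(torch.cuda.memory_allocated())   # one graph: loss0 + intermediates0

new_loss = model(input)                # second forward pass builds graph 1
print(torch.cuda.memory_allocated())   # peak: both graphs are alive right now

loss = new_loss                        # loss0 and intermediates0 are freed here
print(torch.cuda.memory_allocated())   # back to a single graph's worth of memory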