It seems that calling loss.backward() helps save memory, which is not very intuitive to me. Here is the script I used to observe this:
import sys

import torch
import torch.utils
import torch.nn as nn
from torch.autograd import Variable
from torchvision import models

try:
    import gpustat
except ImportError:
    raise ImportError("pip install gpustat")


def show_memusage(device=0):
    # Query the driver via gpustat and print "memory used/memory total" for the GPU.
    gpu_stats = gpustat.GPUStatCollection.new_query()
    item = gpu_stats.jsonify()["gpus"][device]
    print("{}/{}".format(item["memory.used"], item["memory.total"]))


device = 0
show_memusage(device=device)
torch.cuda.set_device(device)

model = models.resnet101(pretrained=False)
model.cuda()
criterion = nn.CrossEntropyLoss()
volatile = False
show_memusage(device=device)

for ii in range(3):
    inputs = torch.randn(20, 3, 224, 224)
    labels = torch.LongTensor(range(20))
    inputs = inputs.cuda()
    labels = labels.cuda()
    inputs = Variable(inputs, volatile=volatile)
    labels = Variable(labels, volatile=volatile)

    print("before run model:", end=" ")
    show_memusage(device=device)
    outputs = model(inputs)
    print("after run model:", end=" ")
    show_memusage(device=device)

    # Only compute the loss and call backward() when the script is run with "1".
    if bool(int(sys.argv[1])):
        loss = criterion(outputs, labels)
        print("before backward:", end=" ")
        show_memusage(device=device)
        loss.backward()
        print("after backward:", end=" ")
        show_memusage(device=device)
    print()
Running “python test_dsetloadermem.py 0” gives (mem used / mem total):
612/8110
1045/8110
before run model: 1055/8110
after run model: 3699/8110
before run model: 3711/8110
after run model: 6215/8110
before run model: 6215/8110
after run model: 6215/8110
Running “python test_dsetloadermem.py 1” gives (mem used / mem total):
618/8110
1051/8110
before run model: 1063/8110
after run model: 3699/8110
before backward: 3699/8110
after backward: 4131/8110
before run model: 4131/8110
after run model: 4131/8110
before backward: 4131/8110
after backward: 4131/8110
before run model: 4131/8110
after run model: 4131/8110
before backward: 4131/8110
after backward: 4131/8110
I’ve also noticed that adding
del outputs
del inputs
del labels
can help avoid unnecessary memory usage (roughly as in the sketch below).
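Just to make that concrete, this is roughly where I put the del calls (a sketch only, reusing the variable names from the script above; nothing here is new API):

for ii in range(3):
    inputs = Variable(torch.randn(20, 3, 224, 224).cuda())
    labels = Variable(torch.LongTensor(range(20)).cuda())

    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()

    # Dropping these references seems to let the allocator reuse the memory
    # on the next iteration instead of holding on to a second set of activations.
    del outputs
    del inputs
    del labels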
Could you give any insight into why loss.backward() does something similar?
Thanks!