CUDA memory leakage

Hi, all

I recently ran into a problem with CUDA memory leakage. All 8 GPUs ran out of their 12GB of memory after a certain number of training steps, and I noticed that the GPU memory usage was building up gradually: roughly 40MB per GPU was added every step, even after forcing cached memory to be released back to the OS with torch.cuda.empty_cache(). Since my training code is fairly simple, I suspect something fishy is going on inside one of the encapsulated modules. Has anyone encountered a similar problem before?

The following is the main body of the training code:

for epoch in range(nepoch):
    for im, rois_h, rois_o, scores, ip, labels in dataloader:

        # relocate tensors to cuda
        im = torch.cat([im] * nGPUs, 0).to(device)
        rois_h = rois_h.to(device)
        rois_o = rois_o.to(device)
        scores = scores.to(device)
        ip = ip.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()
        # perform forward pass
        out = net(im, rois_h, rois_o, scores, ip)

        # compute loss
        loss = criterion(out, labels.float())
        # perform back propagation
        loss.backward()
        optimizer.step()
        # clean up cache
        torch.cuda.empty_cache()

Are you saving some tensors somewhere (e.g. the loss in a list)?
Often memory leaks are created by trying to store some training information like the loss without detaching it from the computation graph, which will store the whole graph with it.
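
For illustration, a minimal sketch of that pattern (the model and variable names here are made up, not taken from your code):

import torch

model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()
history = []

for _ in range(100):
    model.zero_grad()
    out = model(torch.randn(8, 10))
    loss = criterion(out, torch.randn(8, 1))
    loss.backward()
    # leaky: the loss tensor still references its computation graph, so
    # appending the tensor keeps that graph alive as long as the list does
    # history.append(loss)
    # safe: .item() returns a plain Python float
    history.append(loss.item())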


Not at all. There are very few lines of code where I access CUDA.

The first part is where I move the model to CUDA

net = Net()
net = torch.nn.DataParallel(net)
net.to(device)

What I showed in the code snippet at the top is, roughly, the second part, where I get the minibatch and do the forward pass and back propagation, etc. What I didn’t show, though, is just a couple of lines that take a snapshot of the model and print the loss directly to a file, as below.

# save the model at specified steps
if step % snapshot == 0 and step >= snapshot:
    file_name = 'ho-rcnn-model-step' + str(step)
    torch.save(net.state_dict(), os.path.join(path, file_name))

# print statistics
if step % printstats == 0:
    with open(path_stats, 'a') as fid:
        fid.write('\n' + str(datetime.datetime.now()))
        fid.write('\nStep: {}, loss of current batch is {}\n'.format(step, loss))

Apart from that, there isn’t any more usage of CUDA, or any reference to variables stored on the GPU, which is why I think the problem might be caused within one of the encapsulated torch modules.

Thanks for the info!
Could you change loss to loss.item() in the last line of code and see if the memory leak disappears?
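
Applied to the logging snippet you posted, that would be:

with open(path_stats, 'a') as fid:
    fid.write('\nStep: {}, loss of current batch is {}\n'.format(step, loss.item()))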

Hi,
I have experienced something similar, but in my case it was independent of CUDA.
https://discuss.pytorch.org/t/memory-leak-during-backprop-in-reinforcement-learning-tutorial/33968

Unfortunately, that is not the problem.

But I did a bit of memory usage tracking by stepping through the main body of the training code, with all logging excluded, in a Python shell. I found that the significant memory increase happened at two lines of code: the forward pass and backprop.

out = net(im, rois_h, rois_o, scores, ip)
...
loss.backward()

But the memory released after those two lines was always about 30MB smaller than the increase. So now it’s pretty clear that the memory leak happens either in my model or in the backprop module. Given that no one has really complained about the backward() method, it’s highly likely I’ve done something wrong in the model’s forward pass.
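
Roughly, the numbers came from checks along these lines (a sketch around the two lines above, using the variables from the training loop; torch.cuda.memory_allocated() reports the memory currently occupied by tensors):

import torch

def mem_mb():
    # memory currently held by tensors on the default GPU, in MB
    return torch.cuda.memory_allocated() / 1024 ** 2

before = mem_mb()
out = net(im, rois_h, rois_o, scores, ip)
print('forward pass: +{:.1f} MB'.format(mem_mb() - before))

before = mem_mb()
loss = criterion(out, labels.float())
loss.backward()
print('backward pass: +{:.1f} MB'.format(mem_mb() - before))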

Please take a look at the code

def forward(self, im, bh, bo, s, i):
    """
    Forward pass of the model. Note that the inputs are formatted in
    accordance with the three streams and the single-neuron path.

    Inputs:
        im - input images
        bh - the human bounding boxes
        bo - the object bounding boxes
        s - classification scores obtained from the object detector
        i - the interaction pattern for a human-object (ho) pair

    """
    # compute features
    x = self.get_features(im)

    # human stream
    h = self.roipool(x, bh)
    h = h.view(-1, 256 * 6 * 6)
    h = F.dropout(F.relu(self.fc6_h(h)))
    h = self.fc8_h(F.dropout(F.relu(self.fc7_h(h))))

    # object stream
    o = self.roipool(x, bo)
    o = o.view(-1, 256 * 6 * 6)
    o = F.dropout(F.relu(self.fc6_o(o)))
    o = self.fc8_o(F.dropout(F.relu(self.fc7_o(o))))

    # pairwise stream
    i = self.maxpool(F.relu(self.conv1_p(i)))
    i = self.maxpool(F.relu(self.conv2_p(i)))
    i = i.view(-1, 32 * 12 * 12)
    i = self.fc4_p(F.relu(self.fc3_p(i)))

    # single neuron path
    s = self.fc2_s(self.fc1_s(s))

    # perform element-wise addition of the four streams
    return h + o + i + s

def get_features(self, im):
    x = F.local_response_norm(self.maxpool(F.relu(self.conv1(im))), 5)
    x = F.local_response_norm(self.maxpool(F.relu(self.conv2(x))), 5)
    x = F.relu(self.conv3(x))
    x = F.relu(self.conv4(x))
    x = F.relu(self.conv5(x))
    return x

Thanks for debugging!
Your code looks alright. At least I cannot find any obvious issues skimming through it.
Could you additionally post the code for roipool?

I used a third-party implementation from GitHub for ROI pooling, since there doesn’t seem to be a native PyTorch module for it. The module is instantiated as below:

from roi_pooling.modules.roi_pool import RoIPool
...
self.roipool = RoIPool(6, 6, 1.0 / 16)

NEW UPDATE:

I’ve finally found the source of the leak. It does have something to do with the ROI pooling module.

In h = self.roipool(x, bh), the variable h was still stored in memory even though it was overwritten afterwards. This behaviour repeated itself in every single step, which caused the memory usage to snowball.

The variable bh, which consists of a batch of bounding boxes, was also not released after the forward pass, and the same holds for every subsequent step. It is essentially these two variables that were causing the memory leak.

The same problem applies to the variables o and bo in the “object stream” (o = self.roipool(x, bo)).

A possible reason the memory is not released is that these bounding boxes, i.e. bh and bo, have variable length across different batches. I’ve heard that PyTorch seems to have some problems with variable-length batch data, but I’m not sure whether this is really what’s causing the leak.
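
For reference, one way to check which CUDA tensors are still alive between steps is to walk Python’s garbage collector (a debugging sketch, independent of the training code):

import gc
import torch

# list every CUDA tensor that is still referenced somewhere
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) and obj.is_cuda:
            print(type(obj).__name__, tuple(obj.size()))
    except Exception:
        # some tracked objects raise on attribute access; skip them
        pass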

THE PROBLEM IS SOLVED

The memory leak was caused by a Python wrapper of the ROI pooling function.

The attribute self.output holds the output of the ROI pooling module. Since the output is returned anyway, there really isn’t any point in saving it as an attribute of the module. As a consequence, the memory taken by self.output was somehow never released and kept stacking up over the training steps.

After commenting out this attribute, the memory leak was gone.
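
For anyone hitting the same thing, the pattern looks roughly like this (a sketch with made-up names, not the actual third-party wrapper code):

import torch
import torch.nn as nn

class PoolWrapper(nn.Module):
    # illustrative wrapper around some pooling op, standing in for the
    # third-party RoIPool module
    def __init__(self):
        super(PoolWrapper, self).__init__()
        self.pool = nn.AdaptiveMaxPool2d((6, 6))

    def forward(self, x):
        out = self.pool(x)
        # the problematic pattern: caching the output on the wrapper keeps a
        # reference to the tensor (and the graph hanging off it) beyond the
        # current iteration
        # self.output = out
        return out

pool = PoolWrapper()
y = pool(torch.randn(1, 256, 32, 32))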


Awesome that you’ve found the reason for this issue.
CC @longcw in case that’s interesting or a misunderstanding.

Actually, I ran into this issue a long time ago; the solution is to create a new Function object every time you use it, as I did in the module. But as you mentioned, you can comment out self.output, since it doesn’t seem to be used anywhere in the function.

Em… this is a bit odd, since I am indeed using the module rather than calling the function directly. And yes, most of the time the solution seems to work; I didn’t find any problems with it when I was testing the module on its own either.

The abnormal memory leak only happened when I used the module in a fairly large network. Maybe it has something to do with Python, or even PyTorch. Still, I would recommend commenting out the relevant line, since self.output is never referenced again anyway.

Thanks for your attention anyway!

@longcw @ptrblck @DzReal Hi, may I ask whether anyone has an idea why self.output causes the leak, I mean under the hood?


This is great! Thank you.