CUDA memory not being freed?

I’m seeing some weird behavior: CUDA memory isn’t being freed when it should be.
I can reproduce the issue on two different machines:
Machine 1 runs Arch Linux with PyTorch 0.3.1b0+2b47480 on Python 2.7
Machine 2 runs Ubuntu 16.04 with PyTorch 0.3.0.post4 on Python 2.7

The simplest example I could put together to reproduce it looks like this:



##########################################################################
#   FUNCTION BLOCK                                                       #
##########################################################################

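import gc
import subprocess

import torch

# cifar_loader, cifar_resnets, checkpoints, utils, config, plf and aa are
# modules from my own repository

# fxn taken from https://discuss.pytorch.org/t/memory-leaks-in-trans-conv/12492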

def get_gpu_memory_map():   
    result = subprocess.check_output(
        [
            'nvidia-smi', '--query-gpu=memory.used',
            '--format=csv,nounits,noheader'
        ])
    
    return float(result)


def memout_example():    
    assert vars() == {}
    # empty slate 

    # build persistent data 
    val_loader = cifar_loader.load_cifar_data('val', normalize=False, 
                                              batch_size=16, use_gpu=True)
    
    # now loop through batches and show that there's accumulating...
    for batch_no, (batch, labels) in enumerate(val_loader):
        
        # clean up garbage and clear cuda cache as much as possible 
        gc.collect()
        print "BATCH NUMBER: %s" % batch_no
        print "GPU MEMORY: %s" % get_gpu_memory_map()
        assert sorted(vars().keys()) == sorted(['labels', 'val_loader', 
                                                'batch', 'batch_no'])
        torch.cuda.empty_cache()
        
        
        # load things needed for attack 
        base_model = cifar_resnets.resnet32()
        adv_trained_net = checkpoints.load_state_dict_from_filename(
                                           'half_trained_madry.th', base_model)
        adv_trained_net.cuda()
        cifar_normer = utils.DifferentiableNormalize(mean=config.CIFAR10_MEANS,
                                           std=config.CIFAR10_STDS)        
        pgd_perceptual_loss = plf.PerceptualXentropy(adv_trained_net, 
                                          normalizer=cifar_normer, use_gpu=True)
        pgd_attack_obj = aa.LInfPGD(adv_trained_net, cifar_normer,
                                    pgd_perceptual_loss)
        
        adv_images = pgd_attack_obj.attack(batch.cuda(), labels.cuda(), 
                                           l_inf_bound =8.0/255.0, 
                                           step_size=1.0/255.0, 
                                           num_iterations=16, verbose=False)
        
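        # push things to cpu (in hopes it gets them out of the cache)
        # also delete everything and be sure to collect garbage before next mb
        # (note: Tensor.cpu() returns a copy, so these two calls do not move
        # batch/labels in place; Module.cpu() below does move the model in place)
        batch.cpu()
        labels.cpu()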
        del adv_images
        del batch 
        del labels 
        del pgd_attack_obj 
        del pgd_perceptual_loss
        del cifar_normer
        adv_trained_net.cpu()
        del adv_trained_net 
        del base_model 
        
        
    return
    
    
##########################################################################
#   BREAK THE PLANET BLOCK                                               #
##########################################################################
print memout_example()

Hopefully the annotations make things clear, but the gist is that I’m running adversarial attacks across many minibatches from CIFAR. For each minibatch, I delete all references to everything except the loop variables and then reinitialize. My understanding is that if I delete the references, run garbage collection, and then call torch.cuda.empty_cache(), the CUDA memory allocated by the previous minibatch should be released.

However, this is not what I’m witnessing. My output looks like:

Files already downloaded and verified
BATCH NUMBER: 0
GPU MEMORY: 554.0
BATCH NUMBER: 1
GPU MEMORY: 2306.0
BATCH NUMBER: 2
GPU MEMORY: 3896.0
BATCH NUMBER: 3
GPU MEMORY: 5484.0
BATCH NUMBER: 4
GPU MEMORY: 7074.0
BATCH NUMBER: 5
GPU MEMORY: 8664.0
BATCH NUMBER: 6
GPU MEMORY: 10252.0

Until I get an error that looks like
RuntimeError: cuda runtime error (2) : out of memory at /build/python-pytorch/src/pytorch-0.3.1-py2-cuda/torch/lib/THC/generic/THCStorage.cu:58

So somehow, despite aggressively trying to clear CUDA memory, things accumulate and eventually I run out of memory.

I’m happy to share more of my code or host a live Jupyter notebook to demonstrate the issue.
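A side note on the measurement helper: get_gpu_memory_map calls float() on the raw nvidia-smi output, which only works with a single GPU, because the query prints one line per GPU. A minimal sketch that parses the same query into one value per GPU; parse_memory_used is a hypothetical helper name. Keep in mind that memory.used is GPU-wide, so it also counts the CUDA context, PyTorch's cache and any other processes, not just live tensors.

```python
import subprocess


def parse_memory_used(raw):
    """Parse `nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader`
    output (one MiB value per line, one line per GPU) into a list of floats."""
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    return [float(line) for line in raw.splitlines() if line.strip()]


def get_gpu_memory_map():
    """Return {gpu_index: used MiB}. Keys follow nvidia-smi's ordering, which can
    differ from CUDA device ids unless CUDA_DEVICE_ORDER=PCI_BUS_ID is set."""
    raw = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
    )
    return dict(enumerate(parse_memory_used(raw)))
```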

Using volatile=True for the variables that are not used for training might help. Do you use that?

val_batch = Variable(val_batch.cuda(),volatile=True)

In general I don’t explicitly set volatile to be True when creating variables, but I’m also taking gradients with respect to the inputs (Variable(batch.cuda(), requires_grad=True) in this case), so those variables can’t be volatile. The only things that I could make volatile are the weights of the model itself. From this thread (Why can’t model parameters be variables?) it seems that I can’t do that, since model parameters have requires_grad true by default.

Though I also don’t think it should matter, so long as all memory is being freed after each minibatch, which it should be: if I can do one minibatch and then free all memory (thereby returning to the original state), I should be able to do all minibatches.

Do you reload the model etc. in the val_loader loop on purpose?
Could you move it in front of the loop and check again, if the memory is increasing?

It might be that you are holding some references to the model or other objects on the GPU in one of the “init methods” like plf.PerceptualXentropy or aa.LInfPGD. Thus this memory might be collected, since PyTorch cannot free it. Could you check that or give some info on the implementation of these methods?
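To make the "holding references" point concrete: under CPython's reference counting, an object is freed as soon as its last reference disappears, so deleting the local name is not enough while another object (such as a loss or attack wrapper) still stores the model. A stdlib-only sketch; FakeModel and LossWrapper are hypothetical stand-ins for the network and for wrappers like plf.PerceptualXentropy, and weakref.ref is used only to observe whether the object is still alive.

```python
import weakref


class FakeModel(object):
    """Hypothetical stand-in for a network whose parameters live on the GPU."""


class LossWrapper(object):
    """Hypothetical wrapper that, like the loss/attack objects, stores the model."""
    def __init__(self, model):
        self.nets = [model]


model = FakeModel()
alive = weakref.ref(model)      # a weak reference does not keep `model` alive
wrapper = LossWrapper(model)

del model
# wrapper.nets still refers to the object, so it has not been freed
still_alive_after_first_del = alive() is not None

del wrapper
# the wrapper (and its list) was the last owner, so the object is gone now
freed_after_second_del = alive() is None
```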

Do you reload the model etc. in the val_loader loop on purpose?

Yep. I wouldn’t run code like this in practice; I just wanted to demonstrate that I’m aggressively deleting/deallocating memory as far as Python is concerned.

Could you move it in front of the loop and check again, if the memory is increasing?

This is how I typically run things, but as a sanity check I verified that the memory still increases that way (each minibatch is just processed more quickly, since there’s less in-loop overhead).

Could you check that or give some info on the implementation of these methods?

Sure. You can take a look at the whole GitHub repository, but here’s a quick overview:

  • plf.PartialXentropy(...) is an object whose attributes point to the neural net adv_trained_net, as well as to another implicitly defined neural net
  • aa.LInfPGD is also an object that points to the neural net adv_trained_net as well as to the plf.PartialXentropy(...) object
  • aa.LInfPGD(...).attack(...) iteratively updates a Variable that is initialized to batch.data, performing gradient ascent on the loss with respect to the inputs.
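For readers unfamiliar with the attack in the last bullet: L-infinity PGD repeatedly steps along the sign of the input gradient, then projects back into the epsilon-ball around the clean input and into the valid pixel range. A framework-free sketch of that generic update, assuming a plain sign-gradient step with no random start; linf_pgd and grad_fn are illustrative names, not the repository's API, and the defaults mirror the l_inf_bound, step_size and num_iterations values used in the first snippet.

```python
def sign(v):
    """Return 1, 0 or -1 depending on the sign of v."""
    return (v > 0) - (v < 0)


def linf_pgd(x0, grad_fn, l_inf_bound=8.0 / 255.0, step_size=1.0 / 255.0,
             num_iterations=16):
    """Generic L-infinity PGD (ascent) on a flat list of pixels in [0, 1].

    grad_fn(x) must return d(loss)/dx as a list the same length as x.
    """
    x = list(x0)                                    # start from the clean input
    for _ in range(num_iterations):
        g = grad_fn(x)
        # ascent step along the sign of the gradient
        x = [xi + step_size * sign(gi) for xi, gi in zip(x, g)]
        # project onto the L-inf ball around x0 ...
        x = [min(max(xi, x0i - l_inf_bound), x0i + l_inf_bound)
             for xi, x0i in zip(x, x0)]
        # ... and onto the valid pixel range
        x = [min(max(xi, 0.0), 1.0) for xi in x]
    return x
```

With a loss whose gradient is positive everywhere, each pixel ends up exactly l_inf_bound above its clean value unless the [0, 1] range clips it first.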

So I guess my understanding was that as long as Python doesn’t hold a reference to an object and I try to clear the CUDA cache, any PyTorch-initialized objects should be deallocated, but this line:

Thus this memory might be collected, since PyTorch cannot free it.

Suggests that maybe that’s not the case?

I’ll try to find a cleaner example to recreate that doesn’t rely on so much application-specific code.
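On the question of what Python actually frees: reference counting has one blind spot. Objects that refer to each other in a cycle never drop to zero references, so they survive `del` until Python's cyclic garbage collector runs. That is why the loop calls gc.collect() before torch.cuda.empty_cache(); the cache can only hand back blocks that no live tensor is using. A minimal stdlib illustration, with Node and make_cycle as hypothetical names:

```python
import gc
import weakref


class Node(object):
    """Hypothetical object that can take part in a reference cycle."""
    def __init__(self):
        self.partner = None


def make_cycle():
    a, b = Node(), Node()
    a.partner, b.partner = b, a     # a -> b -> a: refcounts never reach zero
    return weakref.ref(a)           # weak probe to observe a's lifetime


gc.disable()                        # keep the automatic collector out of the way
try:
    probe = make_cycle()
    # the local names are gone, but the cycle keeps both nodes alive
    survived_refcounting = probe() is not None
    gc.collect()                    # the cycle detector frees unreachable cycles
    freed_by_collector = probe() is None
finally:
    gc.enable()
```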

Ok, I see. I just skimmed your repo and tried to locate the plf functions.
I created a small example, which leads to a memory issue (also on CPU).


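import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class PartialLoss(object):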
    """ Partially applied loss object. Has forward and zero_grad methods """
    def __init__(self):
        self.nets = []

class PartialXentropy(PartialLoss):
    def __init__(self, classifier):
        super(PartialXentropy, self).__init__()
        self.classifier = classifier
        self.nets.append(self.classifier)    


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(1000, 10000)
        self.fc2 = nn.Linear(10000, 10000)
        
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return x

x = Variable(torch.randn(1, 1000))
model = MyModel()

output = model(x)

part_loss = PartialXentropy(model)

del model
del part_loss

Could you run this code and check your memory?
Even if I delete the model and the part_loss, which holds a reference to the model, I won’t get the memory back. Can you reproduce this?

EDIT: It seems the memory is just peaking and released afterwards. Sorry for the misleading idea.


I’m seeing the same things: the memory will peak and then get released.

Throwing in a time.sleep(.) call before checking the memory (in the toy example only) reliably brings the allocated memory back to its original level.

Doing something similar in my original problem doesn’t seem to help, however, so I think it’s safe to assume that this isn’t just a delay in deallocation.
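The distinction drawn here, a transient peak versus real accumulation, can be checked mechanically by reading memory after every iteration and looking at the trend. A minimal, framework-free sketch; memory_trace and looks_like_leak are hypothetical helpers, measure could be a thin wrapper around get_gpu_memory_map, and the heuristic (memory rises after every iteration past the first) is an assumption rather than a rule.

```python
import time


def memory_trace(step, measure, iterations=5, settle_seconds=0.0):
    """Run step() repeatedly and record measure() before and after each run.

    measure is any zero-argument callable returning a number. settle_seconds
    mirrors the time.sleep call above, giving the reading time to settle.
    """
    readings = [measure()]              # baseline before the first iteration
    for _ in range(iterations):
        step()
        time.sleep(settle_seconds)
        readings.append(measure())
    return readings


def looks_like_leak(readings, tolerance=0.0):
    """True if memory grows after every iteration, ignoring the baseline
    (the first step may legitimately warm up caches)."""
    later = readings[1:]
    return len(later) > 1 and all(b - a > tolerance for a, b in zip(later, later[1:]))
```

Fed the readings from the first post (554.0, 2306.0, 3896.0, ...), looks_like_leak returns True.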

Okay I’ve found a simple way to reproduce this, as well as a fix (though admittedly I think my own inexperience with pytorch is what crippled me here).

The gist is that my loss function was returning a non-scalar Variable, so I was taking gradients with the following block:

if torch.numel(loss) > 1:
  torch.autograd.backward([loss], grad_variables=[input_variables])
else:
  torch.autograd.backward(loss)

thinking that this would do what I wanted, since loss was an Nx1x1x1 tensor (from input variables of shape NxCxHxW) and the i-th entry of the loss depended only on the i-th element of the input. In hindsight that call is wrong: grad_variables is meant to hold the gradient with respect to loss itself, so it should have loss’s shape (e.g. all ones), not be the input variables.

Somehow, taking backward steps from a non-scalar loss this way lets memory accumulate in CUDA without ever being released. Simply ensuring that my loss always returns a scalar, or just passing everything through torch.sum at the end, is enough to fix my problem.
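Why summing is safe here: backward from a non-scalar loss computes a vector-Jacobian product with whatever vector v is supplied, and when each loss entry depends only on its own example, v = all ones (equivalently torch.sum(loss).backward()) recovers every per-example input gradient. In the notation below, L_i is the loss for example i and x_j the j-th example's input:

```latex
\frac{\partial}{\partial x_j}\left(v^\top L\right)
  = \sum_{i=1}^{N} v_i \,\frac{\partial L_i}{\partial x_j}
  = v_j \,\frac{\partial L_j}{\partial x_j}
  \quad \text{since } \frac{\partial L_i}{\partial x_j} = 0 \text{ for } i \neq j,
\qquad
v = \mathbf{1} \;\Rightarrow\;
\frac{\partial}{\partial x_j} \sum_{i=1}^{N} L_i = \frac{\partial L_j}{\partial x_j}.
```

This relies on the independence assumption above; layers that mix examples within a batch, such as batch normalization in training mode, would make the off-diagonal terms nonzero.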

Here’s some code to demo:

##########################################################################
#   IMPORT BLOCK                                                         #
##########################################################################
import gc 

import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.autograd import Variable
import subprocess
import time 

# fxn taken from https://discuss.pytorch.org/t/memory-leaks-in-trans-conv/12492
def get_gpu_memory_map():   
    result = subprocess.check_output(
        [
            'nvidia-smi', '--query-gpu=memory.used',
            '--format=csv,nounits,noheader'
        ])
    
    return float(result)


##########################################################################
#   CLASS BLOCK                                                          #
##########################################################################

class ClassA(object):
    def __init__(self):
        self.nets = [] 
        
class BadSubclass(ClassA):
    def __init__(self, classifier):
        super(BadSubclass, self).__init__()
        self.classifier = classifier 
        self.nets.append(self.classifier)
        
    def forward(self, inp):
        return self.classifier.forward(inp).squeeze()            
    
class GoodSubclass(BadSubclass):
    # reuses BadSubclass.__init__; only forward() differs: it sums the
    # output so the loss is a scalar
    def forward(self, inp):
        return torch.sum(super(GoodSubclass, self).forward(inp))
        
        
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(1000, 10000)
        self.fc2 = nn.Linear(10000, 1000)
        self.fcs = [self.fc1, self.fc2]
        
    def forward(self, x):
        for fc in self.fcs:
            x = fc(x)
            x = F.relu(x)
        return x # 1000 dimension output 

    

##########################################################################
#   EXAMPLE BLOCK                                                        #
##########################################################################

def memout_example(bad_or_good):
    # reuse all the code except for which subclass we use 
    # and which grad technique we use 
    
    
    assert bad_or_good in ['bad', 'good']
    if bad_or_good == 'bad':
        subclass = BadSubclass 
        grad_method = lambda output, inp: torch.autograd.backward(
                                          [output], grad_variables=[inp])
    else:
        subclass = GoodSubclass 
        grad_method = lambda output, inp: output.backward() 
            
    
    # Loop through, pick a random input, run it through model
    # then compute gradients, then clean up as much as possible 
        
    for i in xrange(10):    
        print "LOOP: (%s) | BASE STATE" % i, get_gpu_memory_map()
        x = Variable(torch.randn(1, 1000)).cuda()
        model = MyModel().cuda()

        example = subclass(model)
        out = example.forward(x)
        grad_method(out, x)
        print "LOOP: (%s) | PEAK STATE" % i, get_gpu_memory_map()
        del model 
        del example
        del out 
        del x 
        gc.collect()
        torch.cuda.empty_cache()
        time.sleep(5)

        print "LOOP: (%s) | OUT  STATE" % i, get_gpu_memory_map()   
        print '-' * 29 # pretty prints

Running memout_example('bad') should cause things to accumulate, but memout_example('good') should be okay.
