Model or graph not being removed from CUDA memory

I’m running into CUDA memory issues: my model and/or its computation graph is not being destroyed or removed from GPU memory despite my best efforts. I would really appreciate any help figuring out why/where the model/graph is being retained and how to ensure it’s properly destroyed and removed from the GPU.

There’s a bit going on in my code, so apologies if it’s dense; I can try to summarize if needed.

def run_cross_validation(verbose):
    val_maes = []
    for ft_dfs, test_df in splits_xferlearn(in_df, train_frac = 0.98, num_finetune_steps = 6, num_splits = 3):
        print('***********************************************************************************************')
        print('***********************************************************************************************')

        model = Net([2000, 2000, 1000, 1000, 500, 500, 250, 250, 100], len(ind_cols), len(targ_cols)).to(device)    
        optimizer = torch.optim.Adam(model.parameters(), lr = max_lr)

        test_ds = Tab_Dataset(test_df.reset_index(drop = True), ind_cols = ind_cols, primary_target = 'target', other_targets = [col for col in targ_cols if col != 'target'])
        test_dl = torch.utils.data.DataLoader(test_ds, batch_size = batch_size, shuffle = True)

        # fine tune steps
        for ft_df in ft_dfs[-2:]:
            train_ds = Tab_Dataset(ft_df.reset_index(drop = True), ind_cols = ind_cols, primary_target = 'target', other_targets = [col for col in targ_cols if col != 'target'])
            train_dl = torch.utils.data.DataLoader(train_ds, batch_size = batch_size, shuffle = False)

            scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = max_lr, steps_per_epoch = 1, epochs = num_epochs)

            _ = train_model(model = model, optimizer = optimizer, scheduler = scheduler, train_dl = train_dl, valid_dl = test_dl, num_epochs = num_epochs, alpha = 10, verbose = verbose)
            print('=====================================================')

        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = max_lr, steps_per_epoch = 1, epochs = num_epochs)

        finval_mae = train_model(model, optimizer, scheduler, train_dl = train_dl, valid_dl = test_dl, num_epochs = num_epochs, alpha = 10, verbose = verbose)

        torch.cuda.empty_cache()
        print(torch.cuda.memory_allocated())  
        train_ds.free_memory()
        test_ds.free_memory()
        del model
        torch.cuda.empty_cache()
        print(torch.cuda.memory_allocated())

        val_maes.append(finval_mae)
    return np.mean(val_maes)

def train_model(model, optimizer, scheduler, train_dl, valid_dl, num_epochs = 10, alpha = 1, verbose = True):   
    mae_loss = nn.L1Loss()
    mse_loss = nn.MSELoss()
    
    for epoch in range(num_epochs):
        t_epoch = time.time()
        total, sum_loss = 0, 0
        # train
        model.train()
        for X, y_primary, y_others in train_dl:            
            preds = model(X)
            # calculate losses
            primary_loss = mse_loss(X[:, 0] * y_primary, X[:, 0] * preds[:, 0])
            others_loss = 0
            for ii in range(1, len(targ_cols)):
                others_loss += mse_loss(X[:, 0] * y_others[:, ii - 1], X[:, 0] * preds[:, ii])
            
            loss = alpha * primary_loss + others_loss

            # zero out the gradients, perform the backpropagation step, and update the weights
            optimizer.zero_grad()
            loss.backward()        
            optimizer.step()
            
            # collect metrics
            bs = len(y_primary)
            total += bs            
            sum_loss += bs * loss.item()

        train_loss = sum_loss / total

        # validation
        total, sum_loss, sum_mae = 0, 0, 0
        sum_mask_loss = 0
        with torch.no_grad():
            model.eval()
            for X, y_primary, y_others in valid_dl:
                preds = model(X)
                # calculate losses
                primary_loss = mse_loss(X[:, 0] * y_primary, X[:, 0] * preds[:, 0])
                others_loss = 0
                for ii in range(1, len(targ_cols)):
                    others_loss += mse_loss(X[:, 0] * y_others[:, ii - 1], X[:, 0] * preds[:, ii])
            
                loss = alpha * primary_loss + others_loss

                mae = mae_loss(y_primary, preds[:, 0])
                
                # collect metrics
                bs = len(y_primary)
                total += bs            
                sum_loss += bs * loss.item()
                sum_mae += bs * mae.item()

            val_loss = sum_loss / total
            val_mae = sum_mae / total

        scheduler.step()   
  
    return val_mae

def fc_block(in_n, out_n, drops = 0.3):
    return nn.Sequential(
        nn.Linear(in_n, out_n),
        nn.BatchNorm1d(out_n),
        nn.ReLU(),
        nn.Dropout(drops)
        )

class Net(nn.Module):
    def __init__(self, arch, num_in_cols, num_out_cols, drop_base = 0.2, drop_exp = 2):
        super().__init__()
        dropouts = [drop_base / (drop_exp ** ii) for ii in reversed(range(len(arch)))]
        layer_sizes = [num_in_cols, *arch]
        fc_blocks = [fc_block(in_n, out_n, drops = dropouts[ii]) for ii, (in_n, out_n) in enumerate(zip(layer_sizes, layer_sizes[1:]))]

        self.fc_net = nn.Sequential(*fc_blocks)
        self.fin_lin = nn.Linear(layer_sizes[-1], num_out_cols)
    
    def forward(self, x):
        x = self.fc_net(x)
        return self.fin_lin(x)

class Tab_Dataset(torch.utils.data.Dataset):
    def __init__(self, df, ind_cols, primary_target, other_targets = None):
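        # note: the full dataset (features and all targets) is moved onto device up front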
        self.X = torch.tensor(df[ind_cols].values, dtype = torch.float32).to(device)
        self.y_primary = torch.tensor(df[primary_target].values,  dtype = torch.float32).to(device)
        if other_targets:
            self.y_others = torch.tensor(df[other_targets].values, dtype = torch.float32).to(device)
            self.use_others = True
        else:
            self.use_others = False
        
    def __len__(self): return len(self.y_primary)
    
    def __getitem__(self, idx): 
        X = self.X[idx]              
        y_primary = self.y_primary[idx]
        
        if self.use_others:
            y_others = self.y_others[idx]
            
            return X, y_primary, y_others
        else:
            return X, y_primary
        
    def free_memory(self):
        del self.X, self.y_primary
        # y_others only exists when other_targets was provided
        if self.use_others:
            del self.y_others

Before executing run_cross_validation() there’s no memory allocated on the GPU:

[19]: torch.cuda.memory_allocated()
[19]: 0

Then, upon running it, you can see that quite a lot of memory remains allocated on the GPU, even after manually deleting the data and the model, and even after the function returns:

[20]: run_cross_validation(False)
***********************************************************************************************
***********************************************************************************************
=====================================================
=====================================================
639727616
171721728
***********************************************************************************************
***********************************************************************************************
=====================================================
=====================================================
615805440
173263872
***********************************************************************************************
***********************************************************************************************
=====================================================
=====================================================
590556672
172838912

[20]: 0.1630366715185881
[21]: torch.cuda.memory_allocated()
[21]: 172838912

If I then check all the tensors that still hold data, I see that all that’s left is a series of tensors and parameters matching the shapes of the model’s layers, meaning it’s the model and/or the graph; I’m not sure which.

[22]: 
import gc
for obj in gc.get_objects():
    try:
        # report tensors directly, plus objects wrapping one via a .data attribute
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            print(type(obj), obj.size())
    except Exception:
        pass
[22]:
<class 'torch.Tensor'> torch.Size([2000, 1213])
<class 'torch.Tensor'> torch.Size([2000, 1213])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000, 2000])
<class 'torch.Tensor'> torch.Size([2000, 2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([1000, 2000])
<class 'torch.Tensor'> torch.Size([1000, 2000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000, 1000])
<class 'torch.Tensor'> torch.Size([1000, 1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([500, 1000])
<class 'torch.Tensor'> torch.Size([500, 1000])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500, 500])
<class 'torch.Tensor'> torch.Size([500, 500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([250, 500])
<class 'torch.Tensor'> torch.Size([250, 500])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250, 250])
<class 'torch.Tensor'> torch.Size([250, 250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([100, 250])
<class 'torch.Tensor'> torch.Size([100, 250])
<class 'torch.Tensor'> torch.Size([100])
<class 'torch.Tensor'> torch.Size([100])
<class 'torch.Tensor'> torch.Size([100])
<class 'torch.Tensor'> torch.Size([100])
<class 'torch.Tensor'> torch.Size([100])
<class 'torch.Tensor'> torch.Size([100])
<class 'torch.Tensor'> torch.Size([21, 100])
<class 'torch.Tensor'> torch.Size([21, 100])
<class 'torch.Tensor'> torch.Size([21])
<class 'torch.Tensor'> torch.Size([21])
<class 'torch.Tensor'> torch.Size([2000, 1213])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000, 2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([2000])
<class 'torch.Tensor'> torch.Size([1000, 2000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000, 1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([1000])
<class 'torch.Tensor'> torch.Size([500, 1000])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500, 500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([500])
<class 'torch.Tensor'> torch.Size([250, 500])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250, 250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([250])
<class 'torch.Tensor'> torch.Size([100, 250])
<class 'torch.Tensor'> torch.Size([100])
<class 'torch.Tensor'> torch.Size([100])
<class 'torch.Tensor'> torch.Size([100])
<class 'torch.Tensor'> torch.Size([21, 100])
<class 'torch.Tensor'> torch.Size([21])
<class 'torch.nn.parameter.Parameter'> torch.Size([2000, 1213])
<class 'torch.nn.parameter.Parameter'> torch.Size([2000])
<class 'torch.nn.parameter.Parameter'> torch.Size([2000])
<class 'torch.nn.parameter.Parameter'> torch.Size([2000])
<class 'torch.nn.parameter.Parameter'> torch.Size([2000, 2000])
<class 'torch.nn.parameter.Parameter'> torch.Size([2000])
<class 'torch.nn.parameter.Parameter'> torch.Size([2000])
<class 'torch.nn.parameter.Parameter'> torch.Size([2000])
<class 'torch.nn.parameter.Parameter'> torch.Size([1000, 2000])
<class 'torch.nn.parameter.Parameter'> torch.Size([1000])
<class 'torch.nn.parameter.Parameter'> torch.Size([1000])
<class 'torch.nn.parameter.Parameter'> torch.Size([1000])
<class 'torch.nn.parameter.Parameter'> torch.Size([1000, 1000])
<class 'torch.nn.parameter.Parameter'> torch.Size([1000])
<class 'torch.nn.parameter.Parameter'> torch.Size([1000])
<class 'torch.nn.parameter.Parameter'> torch.Size([1000])
<class 'torch.nn.parameter.Parameter'> torch.Size([500, 1000])
<class 'torch.nn.parameter.Parameter'> torch.Size([500])
<class 'torch.nn.parameter.Parameter'> torch.Size([500])
<class 'torch.nn.parameter.Parameter'> torch.Size([500])
<class 'torch.nn.parameter.Parameter'> torch.Size([500, 500])
<class 'torch.nn.parameter.Parameter'> torch.Size([500])
<class 'torch.nn.parameter.Parameter'> torch.Size([500])
<class 'torch.nn.parameter.Parameter'> torch.Size([500])
<class 'torch.nn.parameter.Parameter'> torch.Size([250, 500])
<class 'torch.nn.parameter.Parameter'> torch.Size([250])
<class 'torch.nn.parameter.Parameter'> torch.Size([250])
<class 'torch.nn.parameter.Parameter'> torch.Size([250])
<class 'torch.nn.parameter.Parameter'> torch.Size([250, 250])
<class 'torch.nn.parameter.Parameter'> torch.Size([250])
<class 'torch.nn.parameter.Parameter'> torch.Size([250])
<class 'torch.nn.parameter.Parameter'> torch.Size([250])
<class 'torch.nn.parameter.Parameter'> torch.Size([100, 250])
<class 'torch.nn.parameter.Parameter'> torch.Size([100])
<class 'torch.nn.parameter.Parameter'> torch.Size([100])
<class 'torch.nn.parameter.Parameter'> torch.Size([100])
<class 'torch.nn.parameter.Parameter'> torch.Size([21, 100])
<class 'torch.nn.parameter.Parameter'> torch.Size([21])
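
A slightly more targeted variant of the same sweep, keeping only CUDA tensors and summing their sizes, looks roughly like this (a sketch; it can double-count tensors that share storage):

import gc
import torch

total_bytes = 0
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) and obj.is_cuda:
            # element_size() * nelement() gives the tensor's data size in bytes
            total_bytes += obj.element_size() * obj.nelement()
    except Exception:
        pass
print(f'bytes held by live CUDA tensors: {total_bytes}')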

Why is the model/graph still on the GPU and what can I do to get it off? Thanks so much to anybody who can help.

I would start debugging the issue by removing any lists to which tensors are appended, since tensors that are not detached can keep the computation graph alive and prevent the model from being fully deleted.
If this doesn’t help, could you please post a minimal, executable code snippet to reproduce the issue?
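
For example, appending the raw loss retains the entire graph of that iteration (a minimal sketch; the accumulator list here is hypothetical):

import torch

model = torch.nn.Linear(4, 1).to('cuda')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
losses = []
for _ in range(3):
    loss = model(torch.randn(8, 4, device='cuda')).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # losses.append(loss) would keep each iteration's computation graph alive;
    # .item() (or .detach()) stores the value without the graph
    losses.append(loss.item())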

OK, I figured it out. It’s the scheduler that’s causing the memory not to be released from the GPU. It seems this was already identified as a bug and fixed in a previous version, but maybe it’s back somehow? Memory leak in ReduceLROnPlateau ? · Issue #17630 · pytorch/pytorch · GitHub

Here’s a minimal example showing the behavior. I’ll make an issue on GitHub if others agree it’s a bug and I’m not missing something…

torch.__version__ returns 1.11.0+cu102

import torch
from torchvision.models import resnet18
import gc

def testfun():
    model = resnet18().to('cuda')
    optimizer = torch.optim.Adam(model.parameters())

print(f'Before function memory: {torch.cuda.memory_allocated()}')
testfun()
print(f'After function memory: {torch.cuda.memory_allocated()}')
>>>
Before function memory: 0
After function memory: 0

If we then add in a scheduler, the memory doesn’t get released when the function returns:

def testfun():
    model = resnet18().to('cuda')
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 1, total_steps = 10)

print(f'Before function memory: {torch.cuda.memory_allocated()}')
testfun()
print(f'After function memory: {torch.cuda.memory_allocated()}')
>>>
Before function memory: 0
After function memory: 46810112

To resolve this, you just have to delete the scheduler inside the function and then garbage collect:

def testfun():
    model = resnet18().to('cuda')
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 1, total_steps = 10)
    del scheduler

print(f'Before function memory: {torch.cuda.memory_allocated()}')
testfun()
gc.collect()
print(f'After function memory: {torch.cuda.memory_allocated()}')
>>>
Before function memory: 0
After function memory: 0
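
Back in run_cross_validation, the equivalent cleanup at the end of each split would presumably be something like:

del model, optimizer, scheduler
gc.collect()
torch.cuda.empty_cache()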

Thanks for sharing this information.
I’m unsure the scheduler is indeed the problematic object here, as another major difference between the second and third approaches is the call to the garbage collector.
If you add gc.collect() to the second code snippet, the memory will also be cleared, so the scheduler is indeed just a dead object; most likely it sits in a reference cycle and is therefore only reclaimed once the cyclic garbage collector runs, rather than by reference counting alone. Allocating any new tensor frees the memory as well:

def testfun():
    model = resnet18().to('cuda')
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 1, total_steps = 10)

print(f'Before function memory: {torch.cuda.memory_allocated()}')
# Before function memory: 0
testfun()
print(f'After function memory: {torch.cuda.memory_allocated()}')
# After function memory: 46810112

x = torch.randn(10, device='cuda')
print(f'After creating x memory: {torch.cuda.memory_allocated()}')
# After creating x memory: 512
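
And, for completeness, here is the second snippet with just gc.collect() added after the call (no del inside the function):

def testfun():
    model = resnet18().to('cuda')
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 1, total_steps = 10)

print(f'Before function memory: {torch.cuda.memory_allocated()}')
# Before function memory: 0
testfun()
gc.collect()  # collects the now-unreachable scheduler/optimizer/model cycle
print(f'After function memory: {torch.cuda.memory_allocated()}')
# After function memory: 0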