CUDA OOM Error, Memory Allocation Keeps Increasing Every Epoch

Hello, I have been trying to get a model to run on a computer vision task and keep getting the usual out-of-memory error. The GPU memory always fills up by the 6th epoch, no matter the batch_size value or whatever else I try. What I tried:

  • using del outputs, loss
  • loss.detach().item()
  • gc.collect() and torch.cuda.empty_cache()

None of this works.

My training loop definition:

def train_model(model, optimizer, criterion=loss_func, metric=dice_metric, n_epochs=20, batch_size=BATCH_SIZE):
    
    model.to(DEVICE)

    # Datasets, dataloaders, optimizer, and scheduler
    train_set = HubMAPDataset(df=train_df, fold=fold, train=True) 
    val_set = HubMAPDataset(df=train_df, fold=fold, train=False) 
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    
    optimizer = optimizer([
                    {'params': model.encoder.parameters(), 'lr': 8e-5},
                    {'params': model.decoder.parameters(), 'lr': 5e-5}
                          ])
    scheduler = OneCycleLR(optimizer=optimizer, pct_start=0.1, div_factor=1e3, 
                                    max_lr=1e-3, epochs=n_epochs, steps_per_epoch=len(train_loader))
    
    result = None
    best_score = 0
    val_scores, train_scores, val_losses, train_losses = [], [], [], []
    best_val_epoch = -1
    
    print(f"Starting Training")
    for epoch in range(n_epochs):
    ########################
    #      TRAINING        #
    ########################
        gc.collect()
        torch.cuda.empty_cache()
        
        model.train()
        epoch_loss, epoch_score = 0, 0
        t = tqdm(train_loader, leave=False)
        for images, labels in t:
        #####################################
        ##             EPOCH              ##
        #####################################
            optimizer.zero_grad()
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            
            outputs = model(images)
           
            loss = criterion(outputs, labels)
            epoch_score += metric(outputs, labels)
            epoch_loss += loss.detach().item()
            loss.backward()
            del loss, outputs
            optimizer.step()
            scheduler.step()
            
        # Statistics Recording
        epoch_loss /= len(train_loader)
        epoch_score /= len(train_loader)
        train_losses.append(epoch_loss)
        train_scores.append(epoch_score)
        
        
        if epoch%5 != 0: 
            print(f"FOLD: {fold}, EPOCH: {epoch + 1}, train_loss: {epoch_loss} , training dice: {epoch_score}") 

        #######################################
        ###           VALIDATION            ###
        #######################################
        if epoch%4 == 0 and epoch != 0:
            model.eval()
            with torch.no_grad():
                valid_loss, val_score = 0, 0 
                t_val = tqdm(val_loader)
                
                for val_images, val_labels in t_val:
                    val_images, val_labels = val_images.to(DEVICE), val_labels.to(DEVICE)
                    outputs = model(val_images)
                    val_score += metric(outputs, val_labels)
                    #val_loss += loss_fn(outputs, val_labels)
                
                val_score /= len(val_loader)
                val_scores.append(val_score)
                
                if val_score > best_score:
                    best_score = val_score
                    torch.save(model.state_dict(), f"{MODEL_NAME}_{ENCODER}-{IMG_SIZE}x{IMG_SIZE}_BestBaseline_{fold}_bsize-{BATCH_SIZE}.pth")
                    print(f"Saving model with best val score : {MODEL_NAME}_{ENCODER}-{IMG_SIZE}x{IMG_SIZE}_BestBaseline_{fold}_bsize-{BATCH_SIZE}.pth")
                
                print(f"FOLD: {fold}, EPOCH: {epoch + 1}")
                print(f"{'#'*30} Validation {'#'*100}")
                print(f"{'#'*30} Train_loss: {epoch_loss} , Train_dice: {epoch_score}, Val dice: {val_score} {'#'*25}")
        print(f"Memory cached in GPU: {torch.cuda.memory_cached()}")

I tracked the memory cached in the GPU after every epoch. I have 16 GB of GPU memory:

Memory cached in GPU: 5486149632
Memory cached in GPU: 7537164288
Memory cached in GPU: 9479127040
Memory cached in GPU: 11607736320
Memory cached in GPU: 13600030720
Memory cached in GPU: 15669919744

Then the error:

RuntimeError: CUDA out of memory. Tried to allocate 72.00 MiB (GPU 0; 15.90 GiB total capacity; 12.07 GiB already allocated; 35.75 MiB free; 15.10 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Any suggestion would be greatly appreciated.


I guess you are keeping the computation graph alive and are extending it in each iteration here:

epoch_score += metric(outputs, labels)

Check if metric is returning a tensor with a valid grad_fn and if so, detach this tensor before accumulating it.
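
For example, something like this inside the training loop would show it (a sketch; metric, outputs, and labels are the objects from your code):

# If this prints a grad_fn, the graph behind `outputs` is being kept alive
score = metric(outputs, labels)
print(score.requires_grad, score.grad_fn)

# Accumulate a detached Python number instead, so the graph can be freed
epoch_score += score.detach().item()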


I checked; only loss has grad_fn=<AddBackward0>.

Edit: Weird thing, putting metric(outputs, labels).item() did work, for some reason. I tested with the loss to see what result I would get with epoch_loss += loss and epoch_loss += loss.item(), but neither makes a difference; the memory consumption is stable at 4.1 GB. However, if I do epoch_score += metric(outputs, labels) instead of epoch_score += metric(outputs, labels).item(), my memory gets eaten up.
This does not make sense to me. Isn’t the loss the variable that is supposed to be attached to the computation graph here, on account of it being the variable I call .backward() on?

I would really like to understand what happened.
This is my metric:

class DiceCoeff(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, y_pred, y_true, smooth=1.):
        y_true = y_true.view(-1)
        y_pred = y_pred.view(-1)
        
        #Round off y_pred
        y_pred = torch.round((y_pred - y_pred.min()) / (y_pred.max() - y_pred.min()))
        
        intersection = (y_true * y_pred).sum()
        union = y_true.sum() + y_pred.sum()
        dice = (2.0 * intersection + smooth)/(union + smooth)
        
        return dice

This is my loss function:

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss,self).__init__()
        self.diceloss = smp.losses.DiceLoss(mode='binary')
        self.binloss = smp.losses.SoftBCEWithLogitsLoss(reduction = 'mean' , smooth_factor = 0.1)

    def forward(self, outputs, mask):
        dice = self.diceloss(outputs,mask)
        bce = self.binloss(outputs , mask)
        loss = dice * 0.7 + bce * 0.3
        return loss

Is there something there that could explain this?

.item() effectively extracts the value of the tensor as “just a Python number”, so when you add it to another value it is treated as a plain number rather than a torch.Tensor that autograd cares about. By default, PyTorch in eager mode has no knowledge of what you are going to call backward() on, so it keeps a graph around for every differentiable operation in case backward() is called later. (Consider that it did not know this about your CustomLoss module until you called backward().) In typical use cases this works well and is convenient for things like dynamism/control flow in models, but it also means that accumulated values that aren’t detached from the graph will consume more and more memory as activations are saved.
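
A minimal standalone sketch of this behavior (toy tensors, not your model):

import torch

a = torch.randn(8, requires_grad=True)

total_tensor = 0
total_number = 0
for _ in range(3):
    out = (a * 2).sum()          # differentiable op, autograd records a graph
    total_tensor += out          # keeps every iteration's graph alive
    total_number += out.item()   # plain Python float, the graph can be freed

print(total_tensor.grad_fn)      # a grad_fn is still attached
print(type(total_number))        # <class 'float'>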


@eqy explained the underlying mechanism and the reason for the increase in memory, and you are also correct that both methods increase the memory usage, as seen in this code:

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

class DiceCoeff(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, y_pred, y_true, smooth=1.):
        y_true = y_true.view(-1)
        y_pred = y_pred.view(-1)
        #Round off y_pred
        y_pred = torch.round((y_pred - y_pred.min()) / (y_pred.max() - y_pred.min()))
        intersection = (y_true * y_pred).sum()
        union = y_true.sum() + y_pred.sum()
        dice = (2.0 * intersection + smooth)/(union + smooth)        
        return dice


class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss,self).__init__()
        self.diceloss = smp.losses.DiceLoss(mode='binary')
        self.binloss = smp.losses.SoftBCEWithLogitsLoss(reduction = 'mean' , smooth_factor = 0.1)

    def forward(self, outputs, mask):
        dice = self.diceloss(outputs,mask)
        bce = self.binloss(outputs , mask)
        loss = dice * 0.7 + bce * 0.3
        return loss


print('Before init, mem allocated {:.3f}MB'.format(torch.cuda.memory_allocated()/1024**2))
metric = DiceCoeff()
criterion = CustomLoss()

model = nn.Linear(1000, 1000, device='cuda')
x = torch.randn(1000, 1000, device='cuda')
target = torch.randint(0, 2, (1000, 1000), device='cuda', dtype=torch.float32)
print('After init, mem allocated {:.3f}MB'.format(torch.cuda.memory_allocated()/1024**2))

epoch_loss, epoch_score = 0, 0
for epoch in range(50):
    output = model(x)
    epoch_loss += criterion(output, target).item()
    epoch_score += metric(output, target).item()
    print('epoch {}, mem allocated {:.3f}MB'.format(epoch, torch.cuda.memory_allocated()/1024**2))

You can run it with the .item() calls enabled or disabled and would get:

# no item() calls
Before init, mem allocated 0.000MB
After init, mem allocated 11.449MB
epoch 0, mem allocated 31.455MB
epoch 1, mem allocated 51.459MB
epoch 2, mem allocated 71.464MB
epoch 3, mem allocated 91.469MB
epoch 4, mem allocated 111.474MB
epoch 5, mem allocated 131.479MB
...
epoch 45, mem allocated 931.674MB
epoch 46, mem allocated 951.679MB
epoch 47, mem allocated 971.684MB
epoch 48, mem allocated 991.689MB
epoch 49, mem allocated 1011.694MB

# criterion.item()
Before init, mem allocated 0.000MB
After init, mem allocated 11.449MB
epoch 0, mem allocated 20.007MB
epoch 1, mem allocated 27.640MB
epoch 2, mem allocated 35.273MB
epoch 3, mem allocated 43.831MB
epoch 4, mem allocated 51.464MB
epoch 5, mem allocated 60.022MB
...
epoch 45, mem allocated 380.139MB
epoch 46, mem allocated 387.772MB
epoch 47, mem allocated 395.405MB
epoch 48, mem allocated 403.963MB
epoch 49, mem allocated 411.596MB

# metric.item()
Before init, mem allocated 0.000MB
After init, mem allocated 11.449MB
epoch 0, mem allocated 26.711MB
epoch 1, mem allocated 43.823MB
epoch 2, mem allocated 59.085MB
epoch 3, mem allocated 76.197MB
epoch 4, mem allocated 91.459MB
epoch 5, mem allocated 106.721MB
...
epoch 45, mem allocated 746.799MB
epoch 46, mem allocated 763.911MB
epoch 47, mem allocated 779.173MB
epoch 48, mem allocated 796.285MB
epoch 49, mem allocated 811.547MB

# Both item()
Before init, mem allocated 0.000MB
After init, mem allocated 11.449MB
epoch 0, mem allocated 15.264MB
epoch 1, mem allocated 16.189MB
epoch 2, mem allocated 15.264MB
epoch 3, mem allocated 16.189MB
epoch 4, mem allocated 15.264MB
epoch 5, mem allocated 16.189MB
...
epoch 45, mem allocated 16.189MB
epoch 46, mem allocated 15.264MB
epoch 47, mem allocated 16.189MB
epoch 48, mem allocated 15.264MB
epoch 49, mem allocated 16.189MB
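
Applied to your training loop, that means accumulating detached Python numbers for both running statistics, e.g.:

epoch_loss += loss.detach().item()
epoch_score += metric(outputs, labels).item()  # or .detach().item()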

Thank you very much, this was very helpful; one learns a lot from a single bug.
