RAM out of memory and process killed from 2nd epoch

My RAM usage keeps increasing after the first epoch. RAM stays at around 30% (roughly 12 GB) during the first epoch of training and validation, but in the second epoch it keeps rising until it reaches 100% (62 GB) and the process is killed. GPU memory remains constant the entire time; only RAM increases. The following warning is printed after the process is killed:

/home/msi_55/Sowmen_2016331055/sowmen_conda_rootenv/lib/python3.7/multiprocessing/semaphore_tracker.py:144: UserWarning: semaphore_tracker: There appear to be 6 leaked semaphores to clean up at shutdown
  len(cache))

All the solutions I’ve found talk about detaching tensors from the computation graph. I’ve done that, but it still didn’t solve the problem. This is my training code:

def train(name, df, patch_size, VAL_FOLD=0, resume=False):

    encoder = SRM_Classifer(encoder_checkpoint='weights/Changed classifier+COMBO_ALL_FULLSRM+ELA_[08|03_21|22|09].h5', freeze_encoder=True)
    model = UnetPP(encoder, num_classes=1, sampling=config.sampling, layer='end')

    SRM_FLAG = 1

    train_geo_aug = albumentations.Compose(
        [
            albumentations.HorizontalFlip(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            albumentations.RandomRotate90(p=0.1),
            albumentations.ShiftScaleRotate(shift_limit=0.01, scale_limit=0.04, rotate_limit=35, p=0.25),
        ],
        additional_targets={'ela':'image'}
    )

    normalize = {
        "mean": [0.4535408213875562, 0.42862278450748387, 0.41780105499276865],
        "std": [0.2672804038612597, 0.2550410416463668, 0.29475415579144293],
    }

    transforms_normalize = albumentations.Compose(
        [
            albumentations.Normalize(mean=normalize['mean'], std=normalize['std'], always_apply=True, p=1),
            albumentations.pytorch.transforms.ToTensorV2()
        ],
        additional_targets={'ela':'image'}
    )

    # -------------------------------- CREATE DATASET and DATALOADER --------------------------
    train_dataset = DATASET(
        dataframe=df,
        mode="train",
        val_fold=VAL_FOLD,
        test_fold=TEST_FOLD,
        patch_size=patch_size,
        resize=256,
        transforms_normalize=transforms_normalize,
        geo_augment=train_geo_aug
    )
    train_loader = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=16, pin_memory=True, drop_last=False)

    valid_dataset = DATASET(
        dataframe=df,
        mode="val",
        val_fold=VAL_FOLD,
        test_fold=TEST_FOLD,
        patch_size=patch_size,
        resize=256,
        transforms_normalize=transforms_normalize,
    )
    valid_loader = DataLoader(valid_dataset, batch_size=config.valid_batch_size, shuffle=True, num_workers=16, pin_memory=True, drop_last=False)

    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    model = nn.DataParallel(model)
    model.to(device)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        patience=config.schedule_patience,
        mode="min",
        factor=config.schedule_factor,
    )

    criterion = losses.DiceLoss(mode='binary', log_loss=True, smooth=1e-7)
    es = EarlyStopping(patience=15, mode="min")

    start_epoch = 0
    for epoch in range(start_epoch, config.epochs):
        print(f"Epoch = {epoch}/{config.epochs-1}")
        print("------------------")

        # if epoch == 2:
        #     model.module.encoder.unfreeze()

        train_metrics = train_epoch(model, train_loader, optimizer, criterion, epoch, SRM_FLAG)

        valid_metrics = valid_epoch(model, valid_loader, criterion,  epoch)
        scheduler.step(valid_metrics["valid_loss_segmentation"])

        print(
            f"TRAIN_LOSS = {train_metrics['train_loss_segmentation']}, \
            TRAIN_DICE = {train_metrics['train_dice']}, \
            TRAIN_JACCARD = {train_metrics['train_jaccard']},"
        )
        print(
            f"VALID_LOSS = {valid_metrics['valid_loss_segmentation']}, \
            VALID_DICE = {valid_metrics['valid_dice']}, \
            VALID_JACCARD = {valid_metrics['valid_jaccard']},"
        )

        es(
            valid_metrics["valid_loss_segmentation"],
            model,
            model_path=os.path.join(OUTPUT_DIR, f"{name}_[{dt_string}].h5"),
        )
        if es.early_stop:
            print("Early stopping")
            break
    
def train_epoch(model, train_loader, optimizer, criterion, epoch, SRM_FLAG):
    model.train()

    segmentation_loss = AverageMeter()
    targets = []
    outputs = []

    for batch in tqdm(train_loader):
        images = batch["image"].to(device)
        elas = batch["ela"].to(device)
        gt = batch["mask"].to(device)

        optimizer.zero_grad()
        out_mask = model(images, elas)

        loss_segmentation = criterion(out_mask, gt)
        loss_segmentation.backward()

        optimizer.step()

        if SRM_FLAG == 1:
            bayer_mask = torch.zeros(3,3,5,5).cuda()
            bayer_mask[:,:,5//2, 5//2] = 1
            bayer_weight = model.module.encoder.bayer_conv.weight * (1-bayer_mask)
            bayer_weight = (bayer_weight / torch.sum(bayer_weight, dim=(2,3), keepdim=True)) + 1e-7
            bayer_weight -= bayer_mask
            model.module.encoder.bayer_conv.weight = nn.Parameter(bayer_weight)
            
        # ---------------------Batch Loss Update-------------------------
        segmentation_loss.update(loss_segmentation.detach().item(), train_loader.batch_size)

        with torch.no_grad():
            out_mask = torch.sigmoid(out_mask).squeeze(1)
            out_mask = out_mask.cpu().detach()
            gt = gt.cpu().detach()
            
            targets.extend(list(gt))
            outputs.extend(list(out_mask))

        gc.collect()

    print("~~~~~~~~~~~~~~~~~~~~~~~~~")
    dice, _ = seg_metrics.dice_coeff(outputs, targets) 
    jaccard, _ = seg_metrics.jaccard_coeff(outputs, targets)  
    print("~~~~~~~~~~~~~~~~~~~~~~~~~")

    train_metrics = {
        "train_loss_segmentation": segmentation_loss.avg,
        "train_dice": dice.item(),
        "train_jaccard": jaccard.item(),
        "epoch" : epoch
    }
    return train_metrics


def valid_epoch(model, valid_loader, criterion, epoch):
    model.eval()

    segmentation_loss = AverageMeter()
    targets = []
    outputs = []

    example_images = []
    image_names = []
    
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            images = batch["image"].to(device)
            elas = batch["ela"].to(device)
            gt = batch["mask"].to(device)
            
            out_mask = model(images, elas)
          
            loss_segmentation = criterion(out_mask, gt)

            # ---------------------Batch Loss Update-------------------------
            segmentation_loss.update(loss_segmentation.item(), valid_loader.batch_size)

            out_mask = torch.sigmoid(out_mask).squeeze(1)
            out_mask = out_mask.cpu().detach()
            gt = gt.cpu().detach()
            
            targets.extend(list(gt))
            outputs.extend(list(out_mask))

    print("~~~~~~~~~~~~~~~~~~~~~~~~~")       
    dice, best_dice = seg_metrics.dice_coeff(outputs, targets)  
    jaccard, best_iou = seg_metrics.jaccard_coeff(outputs, targets) 
    print("~~~~~~~~~~~~~~~~~~~~~~~~~")

    
    valid_metrics = {
        "valid_loss_segmentation": segmentation_loss.avg,
        "valid_dice": dice.item(),
        "valid_jaccard": jaccard.item(),
        "epoch" : epoch
    }
    return valid_metrics

How can I solve this error? I’ve detached everything and tried everything I can think of, but I don’t understand why the 1st epoch is fine while memory only starts increasing in the 2nd epoch.

Does it still crash if you reduce the batch size?

Yes. I’ve tried reducing the batch size to 4, and RAM still keeps exploding.

Could you try to narrow down the issue a bit further by removing specific parts of your code and checking whether the memory is still increasing?
E.g. you could start with the (metric) logging, then the transformations, then using a single worker, etc.
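
To see where the host memory grows while you strip things out, one option is to print the resident set size of the main process (and of the DataLoader workers) at the epoch boundaries. A minimal sketch using psutil; the helper name log_host_memory is just an illustration:

import os
import psutil

def log_host_memory(tag=""):
    # Resident set size of the current (main) process, in GB
    main = psutil.Process(os.getpid())
    main_gb = main.memory_info().rss / 1024 ** 3
    # DataLoader workers are child processes, so sum their RSS as well
    workers_gb = sum(c.memory_info().rss for c in main.children(recursive=True)) / 1024 ** 3
    print(f"[{tag}] main RSS = {main_gb:.2f} GB, worker RSS = {workers_gb:.2f} GB")

Calling it before and after train_epoch and valid_epoch (or every few hundred batches inside the loops) should show whether the growth happens in the training loop, the validation loop, or in between.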

Okay, I found the error. It was happening in these lines:

targets.extend(list(gt))
outputs.extend(list(out_mask))

Since all predictions are stored in RAM, the available memory gets exhausted. Still, this doesn’t explain why the 1st epoch runs without any problems.
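
One way to avoid holding every prediction in RAM is to accumulate per-batch sums and compute Dice/Jaccard from the running totals at the end of the epoch. Below is a sketch assuming a standard soft-Dice/IoU formulation; since the internals of seg_metrics.dice_coeff aren’t shown, this micro-average over all pixels may differ slightly from the original per-sample metric:

class RunningSegMetrics:
    """Accumulate per-batch sums instead of storing every prediction in RAM."""

    def __init__(self, smooth=1e-7):
        self.smooth = smooth
        self.intersection = 0.0
        self.pred_sum = 0.0
        self.target_sum = 0.0

    def update(self, out_mask, gt):
        # out_mask, gt: detached CPU tensors with matching element counts, values in [0, 1]
        out_mask = out_mask.float().reshape(-1)
        gt = gt.float().reshape(-1)
        self.intersection += (out_mask * gt).sum().item()
        self.pred_sum += out_mask.sum().item()
        self.target_sum += gt.sum().item()

    def dice(self):
        return (2 * self.intersection + self.smooth) / (self.pred_sum + self.target_sum + self.smooth)

    def jaccard(self):
        union = self.pred_sum + self.target_sum - self.intersection
        return (self.intersection + self.smooth) / (union + self.smooth)

Inside train_epoch and valid_epoch, metrics.update(out_mask, gt) would then replace targets.extend(list(gt)) and outputs.extend(list(out_mask)), so only three running sums are kept per epoch instead of every mask.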