How to save system memory when training on a large dataset?

I am using PyTorch to train a regression model on a very large dataset: each image is 200×200×2200, and there are 10,000 images in total. System memory (not GPU memory) grows over the course of an epoch until total usage reaches the size of the entire dataset, as if all the data had been loaded into system memory. I use a DataLoader to generate the data in batches and transfer each batch to the CUDA device.

I expected the system memory to be released after each batch was loaded, but that is not happening. Can anyone tell me how to fix this? The dataset is really too large, and I have only limited system memory…

Are you using a custom Dataset?
If so, could you post its implementation as well as the training loop, so that we can check whether some tensors are (accidentally) being stored?
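For illustration, a common way tensors get stored by accident (described in the PyTorch FAQ under not accumulating history across the training loop) is keeping the loss tensor itself instead of its Python value. Each stored loss tensor keeps its autograd graph alive. A minimal sketch; the toy `nn.Linear` model and the random data are placeholders, not the code from this thread:

```python
import torch
import torch.nn as nn

# Placeholder model and data, only to show the two ways of recording the loss.
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

history_bad, history_good = [], []
for _ in range(3):
    loss = criterion(model(x), y)
    history_bad.append(loss)          # keeps the tensor AND its autograd graph alive
    history_good.append(loss.item())  # plain Python float; the graph can be freed
```

Storing `loss.item()` (as the training loop later in this thread already does) avoids this particular leak.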


Thanks very much for your reply. Yes, I am using a custom Dataset. Here is the implementation:

```python
import os

import numpy as np
import torch
from torch.utils.data import Dataset


# training dataset
class CNNDataset(Dataset):
    def __init__(self, length, prefix, root_dir, transform=None):
        self.length = length
        self.prefix = prefix
        self.root_dir = root_dir
        self.transform = transform  # callable applied to [image, para], or None

    def __len__(self):
        return int(self.length)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, 'LCT9526_nm',
                                self.prefix + str(idx) + '.npy')
        para_name = os.path.join(self.root_dir, 'LCTo9526_nm',
                                 self.prefix + str(idx) + '-y' + '.npy')
        # each call reads only one image/target pair from disk
        image = np.load(img_name)
        para = np.load(para_name)
        # scale the four target parameters
        para[0] = para[0] / 2
        para[1] = para[1] / 3
        para[2] = para[2] / 4
        para[3] = para[3] / 5
        sample = [image, para]

        if self.transform is not None:
            sample = self.transform(sample)

        return sample


# validation dataset: same layout, file numbers shifted by `base`
class CNNDatasetv(Dataset):
    def __init__(self, length, base, prefix, root_dir, transform=None):
        self.length = length
        self.base = base
        self.prefix = prefix
        self.root_dir = root_dir
        self.transform = transform  # callable applied to [image, para], or None

    def __len__(self):
        return int(self.length)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, 'LCT9526_nm',
                                self.prefix + str(idx + self.base) + '.npy')
        para_name = os.path.join(self.root_dir, 'LCTo9526_nm',
                                 self.prefix + str(idx + self.base) + '-y' + '.npy')
        image = np.load(img_name)
        para = np.load(para_name)
        # scale the four target parameters
        para[0] = para[0] / 2
        para[1] = para[1] / 3
        para[2] = para[2] / 4
        para[3] = para[3] / 5
        sample = [image, para]

        if self.transform is not None:
            sample = self.transform(sample)

        return sample
```

(Because of how my data is organized, I have to define two Dataset classes, one for training and one for validation.)
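Since the two classes differ only in the `base` offset added to the file number, a single class could serve training, validation and testing. A minimal sketch, assuming the same directory layout as above (`LCT9526_nm` for images, `LCTo9526_nm` for targets); the class name `NpyPairDataset` and the `img_dir`/`para_dir` parameters are hypothetical:

```python
import os

import numpy as np
import torch
from torch.utils.data import Dataset


class NpyPairDataset(Dataset):
    """Hypothetical merged CNNDataset/CNNDatasetv: file number = idx + base."""

    def __init__(self, length, prefix, root_dir, base=0, transform=None,
                 img_dir='LCT9526_nm', para_dir='LCTo9526_nm'):
        self.length = int(length)
        self.base = base
        self.prefix = prefix
        self.root_dir = root_dir
        self.transform = transform
        self.img_dir = img_dir
        self.para_dir = para_dir

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        n = idx + self.base  # file number on disk
        image = np.load(os.path.join(self.root_dir, self.img_dir,
                                     f'{self.prefix}{n}.npy'))
        para = np.load(os.path.join(self.root_dir, self.para_dir,
                                    f'{self.prefix}{n}-y.npy'))
        # same per-parameter scaling as the original classes
        for i, scale in enumerate((2, 3, 4, 5)):
            para[i] = para[i] / scale
        sample = [image, para]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
```

With this, the training set would use `base=0` and the validation set `base=base1`, with everything else unchanged.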
Here is how I create the datasets and data loaders:

```python
import torch
from torchvision import transforms

# lenghtr, lenghva, lenghte, base1, base2, the custom ToTensor transform and
# the CNNDatasett test class are defined elsewhere (not shown).
def create_datasets(batch_size):
    trainset = CNNDataset(length=lenghtr, prefix='Idlt-',
                          root_dir='/scratch/zxs/',
                          transform=transforms.Compose([ToTensor()]))
    validateset = CNNDatasetv(length=lenghva, base=base1, prefix='Idlt-',
                              root_dir='/scratch/zxs/',
                              transform=transforms.Compose([ToTensor()]))
    testset = CNNDatasett(length=lenghte, base=base2, prefix='Idlt-',
                          root_dir='/scratch/zxs/',
                          transform=transforms.Compose([ToTensor()]))

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                               shuffle=True, num_workers=0,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(validateset, batch_size=1,
                                               shuffle=True, num_workers=0,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=1,
                                              shuffle=False, num_workers=0)

    return train_loader, valid_loader, test_loader
```

And here is my training loop:

```python
import sys

import numpy as np
import torch

# device, optimizer, criterion, scheduler and EarlyStopping are defined
# elsewhere in my script (not shown).
def train_model(patience, n_epochs, save_early_path, net, train_loader, valid_loader):
    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(1, n_epochs + 1):  # loop over the dataset multiple times
        net.train()
        for i, data in enumerate(train_loader, 0):
            inputs = data[0].to(device, non_blocking=True)
            para = data[1].to(device, non_blocking=True)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, para)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())  # store a float, not the tensor
            # releases cached GPU memory only; does not affect system (CPU) memory
            torch.cuda.empty_cache()

        ######################
        # validate the model #
        ######################
        net.eval()
        with torch.no_grad():
            for val in valid_loader:
                inputsv = val[0].to(device, non_blocking=True)
                parav = val[1].to(device, non_blocking=True)
                outputsv = net(inputsv)
                loss = criterion(outputsv, parav)
                valid_losses.append(loss.item())

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        scheduler.step(valid_loss)
        # sys.getsizeof reports only the shallow size of the scheduler object
        print('scheduler', sys.getsizeof(scheduler), flush=True)

        epoch_len = len(str(n_epochs))

        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')

        print(print_msg)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []

        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        state = {
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'valid_loss': valid_loss,
        }
        early_stopping(valid_loss, state, save_early_path)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # load the last checkpoint with the best model
    checkpoint = torch.load(save_early_path)
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return avg_train_losses, avg_valid_losses
```

In principle, PyTorch should not load all of the training data into memory, so there may be a bug somewhere in the pipeline. Thanks for posting this; I'll keep following this thread. If I were you, I would do the following: instead of using the DataLoader, load each image directly, convert it to a tensor, and use that tensor in the training loop. If the memory issue then goes away, the bug is on the Dataset/DataLoader side. If memory still climbs, check whether the image data is being retained somewhere…
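The experiment suggested above could look roughly like this sketch. It assumes Linux (it reads `VmRSS` from `/proc/self/status` to get the process's resident memory) and one `.npy` image per file; `current_rss_mib`, `run_without_dataloader` and the `step_fn` callback (where the forward/backward/optimizer step would go) are hypothetical names:

```python
import numpy as np
import torch


def current_rss_mib():
    """Resident memory of this process in MiB. Linux only: parses /proc/self/status."""
    with open("/proc/self/status") as f:
        for line in f:
            if line.startswith("VmRSS:"):
                return int(line.split()[1]) / 1024  # VmRSS is reported in kB
    raise RuntimeError("VmRSS not found in /proc/self/status")


def run_without_dataloader(paths, step_fn, device="cpu", log_every=100):
    """Load each .npy file directly (no Dataset/DataLoader), turn it into a
    tensor, call step_fn on it, and record the process RSS every log_every steps."""
    rss_log = []
    for step, path in enumerate(paths):
        image = torch.from_numpy(np.load(path)).unsqueeze(0)  # add a batch dimension
        image = image.to(device)
        step_fn(image)  # e.g. forward + backward + optimizer.step()
        del image       # drop this loop's reference before loading the next file
        if step % log_every == 0:
            rss = current_rss_mib()
            rss_log.append(rss)
            print(f"step {step}: RSS = {rss:.1f} MiB", flush=True)
    return rss_log
```

If the logged RSS stays roughly flat here but grows with the DataLoader, the problem is likely on the data-loading side; if it grows here too, look at what `step_fn` keeps around.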

Thank you, and sorry for the late reply. I think it's a good idea and worth trying. I found that after simplifying my network and recording only loss.item() in the training loop, the memory error went away. However, my code still seems to consume as much memory as I allocate, which is strange. Anyway, thanks again for your reply and your time.

I've run into a similar problem when training in a Kaggle kernel. System memory keeps increasing step by step and epoch by epoch. I have checked my notebook and compared it with others that work fine, but I still can't find the cause.

I think it's better to reply here than to open a new thread. The original post, with the source code, wandb logs and more, is on Stack Overflow. I posted it there because I also use PyTorch Lightning, so it may not be a pure PyTorch problem. Still, it would be great to know whether PyTorch itself is fine in this case. Thanks :smiley:
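To narrow down memory that grows epoch by epoch, one common debugging approach is counting the tensor objects that are still alive, for example at the end of every epoch. A minimal sketch; `count_live_tensors` is a hypothetical helper built on `gc.get_objects()`:

```python
import gc

import torch


def count_live_tensors():
    """Count torch.Tensor objects currently tracked by Python's garbage collector."""
    gc.collect()  # drop unreachable cycles first so only live objects are counted
    n = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                n += 1
        except Exception:  # some objects misbehave when inspected; skip them
            pass
    return n
```

Printing this number next to the epoch's loss would show whether the count keeps climbing from one epoch to the next, which would suggest tensors are being kept somewhere (for example in a list or a logging callback). It only sees tensors visible from Python: tensors held purely on the C++ side (such as those saved inside an autograd graph) and plain NumPy arrays are not counted.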