Memory consumption rises rapidly when increasing num_workers in DataLoader

I am working with video datasets in PyTorch, but the memory consumption is too high for a normal machine (15 GB with 8 workers, roughly 1 GB more per additional worker). I wonder if I have done something wrong in the dataloader or the dataset. Here is my code:

import os

import torch

class SequenceDataset(torch.utils.data.Dataset):
    """
        Base dataset of video sequence

        Inputs:

            root : str, path to the dataset
            dataset : str, subset, chosen from 'train', 'val', 'test' and ''
    """
    def __init__(self, root, dataset=''):
        super().__init__()
        self.root = root
        self.dataset = dataset

        self.datapath = os.path.join(self.root, dataset)

        # load trajlist
        filenames = sorted(os.listdir(self.datapath))
        self.trajlist = [os.path.join(self.datapath, filename) for filename in filenames]

    def __len__(self):
        return len(self.trajlist)

    def __getitem__(self, index):
        traj_file = self.trajlist[index]
        output = load_npz(traj_file) # each file contains a uint8 ndarray with shape [500, 64, 64, 3], so it is roughly 6 MB

        return output
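(For context, load_npz is just a small helper that reads one .npz file into a dict of arrays. A minimal sketch of it, assuming each file stores named uint8 arrays, would be:)

import numpy as np

def load_npz(path):
    # minimal sketch of the helper: materialize every array stored in the .npz file into a dict
    with np.load(path) as npz:
        return {key: npz[key] for key in npz.files}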

I only use a sub-sequence of the data in training, so I wrote a wrapper for that.

import random

import torch

class Split(torch.utils.data.Dataset):
    """
        Return a sub-sequence of length `horizon` from each trajectory
    """
    def __init__(self, dataset, horizon, fix_start=False):
        super().__init__()
        self._dataset = dataset
        self.horizon = horizon
        self.fix_start = fix_start

    def __getattr__(self, name):
        return getattr(self._dataset, name)

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, index):
        data = self._dataset[index]

        max_length = list(data.values())[0].shape[0]

        start = 0 if self.fix_start else random.randint(0, max_length - self.horizon)
        end = start + self.horizon

        return {k : v[start:end] for k, v in data.items()}

And the DataLoader is created like this:

dataset = Split(SequenceDataset(path), horizon=32, fix_start=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=8, pin_memory=True)

It uses 15 GB of memory. I only have 200 npz files in the dataset, so even if I loaded them all at once they should take only about 1.2 GB. So I am confused about what the rest of the memory is used for.
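As a sanity check of that estimate (assuming uint8 frames, as in the comment in SequenceDataset above):

n_files = 200
bytes_per_file = 500 * 64 * 64 * 3       # uint8 -> 1 byte per element, roughly 6 MB per file
print(n_files * bytes_per_file / 1e9)    # ~1.2 (GB) for all 200 files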

I have just gone through the code of DataLoader. The multiprocessing version prefills 2 * num_workers mini-batches, so if all of those batches are loaded, the memory consumption should be about 2 * 8 * 32 * 500 * 3 * 64 * 64 * 4 / (1024 ** 3) = 12 GB. This explains the huge memory consumption.

To free the additional data, i.e. the part outside the split, we need to copy the sliced data so that the reference to the full loaded ndarray is released. That is, instead of

return {k : v[start:end] for k, v in data.items()}

we should

return {k : v[start:end].copy() for k, v in data.items()}
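A quick way to see why the .copy() matters: a basic slice of a numpy array is a view that keeps the whole parent array alive, whereas .copy() produces an independent buffer. A minimal illustration, using the shapes from this dataset:

import numpy as np

full = np.zeros((500, 64, 64, 3), dtype=np.uint8)   # ~6 MB, like one loaded trajectory
view = full[0:32]                                    # a view: it keeps a reference to `full`
print(view.base is full)                             # True -> the whole 500-frame array stays in memory
sub = full[0:32].copy()                              # an independent ~0.4 MB buffer
print(sub.base is None)                              # True -> `full` can be freed once no longer referenced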