Best practice to cache the entire dataset during the first epoch


Currently, I am in the following situation: the dataset is stored in a single file on a shared file system, and too many processes accessing that file slow the file system down (for example, 40 jobs with 20 workers each end up as 800 processes reading from the same file). So I plan to load the dataset into memory.

I have enough memory (~500 GB) to hold the entire dataset (for example, ImageNet-1k), but loading the dataset before training is too slow. I would like to know if there is a good way to cache the entire dataset during the first epoch, so that after the first epoch the workers close the file and read directly from memory.



One approach would be to store the already loaded samples in the Dataset and reuse them in later epochs.
The downside is that you are limited to num_workers=0 in the first epoch.

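import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, use_cache=False):
        self.data = torch.randn(100, 1)
        self.cached_data = []
        self.use_cache = use_cache

    def __getitem__(self, index):
        if not self.use_cache:
            x = self.data[index]  # your slow data loading
            self.cached_data.append(x)
        else:
            x = self.cached_data[index]
        return x

    def set_use_cache(self, use_cache):
        if use_cache:
            self.cached_data = torch.stack(self.cached_data)
        else:
            self.cached_data = []
        self.use_cache = use_cache

    def __len__(self):
        return len(self.data)

dataset = MyDataset(use_cache=False)
loader = DataLoader(
    dataset,
    num_workers=0,  # first epoch: fill the cache in the main process
    shuffle=False   # samples are appended in loading order, so keep the order fixed
)

for data in loader:
    print(len(loader.dataset.cached_data))

loader.dataset.set_use_cache(use_cache=True)
loader.num_workers = 2
for data in loader:
    print(len(loader.dataset.cached_data))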
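One caveat with the list-based cache: samples are appended in the order they are loaded, so cached_data[index] only returns the right sample if the first epoch visits the indices in order (shuffle=False). A minimal variation, sketched below with the same toy data (the class name IndexCachedDataset is hypothetical), keys the cache by index so the first epoch can be shuffled:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class IndexCachedDataset(Dataset):
    """Hypothetical variant: caches each sample under its index instead of appending."""

    def __init__(self, n=100):
        self.data = torch.randn(n, 1)  # stand-in for slow-to-load data
        self.cache = {}                # index -> sample
        self.use_cache = False

    def __getitem__(self, index):
        if self.use_cache:
            return self.cache[index]
        x = self.data[index]           # your slow data loading
        self.cache[index] = x          # position in the cache no longer depends on load order
        return x

    def __len__(self):
        return len(self.data)


dataset = IndexCachedDataset()
# First epoch: main process only (num_workers=0), so the cache lands in this object,
# but the order can now be shuffled.
for _ in DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0):
    pass
dataset.use_cache = True
# Later epochs could use num_workers > 0; each worker then gets a copy of the filled cache.
epoch2 = torch.cat([batch for batch in DataLoader(dataset, batch_size=10, shuffle=False)])
```

Workers started in later epochs still each receive their own copy of the dict, so the memory caveat discussed in this thread applies unchanged.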

I tried this solution for images and got this error:

File "/home/sagarbhonde/kiwi/kd/", line 90, in set_use_cache
    self.cached_images = torch.stack(self.cached_images)
TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

So you might want to change the set_use_cache method to:

    def set_use_cache(self, use_cache):
        if use_cache:
            # convert to a tuple of tensors before stacking
            x_img = tuple(self.cached_images)
            self.cached_images = torch.stack(x_img)
        else:
            self.cached_images = []
        self.use_cache = use_cache
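The TypeError above is what torch.stack raises when it receives a single Tensor instead of a sequence of tensors, so cached_images had most likely already been stacked at that point, for example by calling set_use_cache(True) a second time. tuple(...) avoids the error because iterating over a tensor splits it along its first dimension, but it then just re-stacks the same data. A sketch of an alternative that only stacks once (ImageCache is a hypothetical minimal holder for the attributes used in the thread):

```python
import torch


class ImageCache:
    """Hypothetical minimal holder for the cached_images / use_cache attributes."""

    def __init__(self):
        self.cached_images = []
        self.use_cache = False

    def set_use_cache(self, use_cache):
        if use_cache:
            # Stack only while the cache is still a list; after the first call it is
            # already a Tensor, and torch.stack(tensor) raises the TypeError shown above.
            if isinstance(self.cached_images, list):
                self.cached_images = torch.stack(self.cached_images)
        else:
            self.cached_images = []
        self.use_cache = use_cache


c = ImageCache()
c.cached_images = [torch.zeros(3, 4, 4) for _ in range(5)]
c.set_use_cache(True)
c.set_use_cache(True)  # safe to call again, e.g. once per epoch
print(c.cached_images.shape)  # torch.Size([5, 3, 4, 4])
```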

@ptrblck I was just reading this response you wrote back in 2018. I am trying to get some basic speed-up in my DataLoader (loading images). Is what you wrote here still valid? I saw you also responded to a similar post here, and I wasn’t sure whether you have to use a multiprocessing Array in most cases. I liked the simplicity of what you show here. I am doing 5 folds, 15 epochs per fold, and 15x test-time augmentation on validation and test, so that’s a lot of time spent reading the same files from disk over and over, and I am hoping I can use something like this to speed things up a bit.

The posted approach might still be valid, but note that the first pass uses num_workers=0 in order to add the loaded samples to the Dataset.
The following epochs can use more workers, but each worker process will hold a copy of the complete Dataset, which would increase the memory usage roughly by a factor of num_workers.

The linked post uses a shared array instead, so that multiple workers can write the loaded samples directly into it.

Let me know if you encounter any issues.

@ptrblck I did encounter some issues; I’ll post them in the other thread, though, as they’re relevant to that method.

Thank you for your answer, it works great.
Can you explain why num_workers must be zero, and what happens when it is not?

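The first epoch fills the “cache” in the original Dataset object in the main process (num_workers=0). The following epochs can then use multiple workers and reuse this cache, since each worker creates a copy of the dataset that already contains the cached data. If you used workers in the first epoch, each worker would append to its own copy, and the cache in the main process would stay empty.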
Note that this is not an optimal approach, but more a proof of concept.

I placed the cache object in the dataset class as a class variable, so that copies would still “see and use” the same cached object. Although this works in the console, it doesn’t when using num_workers > 0, and copying shouldn’t cause this effect (as far as I understand Python’s internals, and I guess I don’t).

In your reply you wrote “not an optimal approach…”. What do you recommend when handling images stored as files?

You could use shared arrays instead of this list; the list is what forces you to use only the main process while the cache is being filled.

I put together a version of what I used to cache my preprocessed images. This implementation saves some memory and gives some speedup, since it avoids resizing the images in each process. I am not absolutely sure about the torch.multiprocessing.set_sharing_strategy('file_system') call, and maybe @ptrblck has some comments on it, but it was suggested by PyTorch, and without it I got an exception after 14+ hours of training when passing the same DataLoader to 30+ different model training loops.
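For context, torch.multiprocessing supports two sharing strategies on Linux. The default, 'file_descriptor', keeps a file descriptor open for each storage shared between processes, which can hit the per-process open-file limit when many tensors are shared; 'file_system' instead identifies shared memory by file name, at the risk of leaking shared-memory files if processes die unexpectedly. Which of these caused the exception above is not stated, so the following is only a sketch of how the call is typically made:

```python
import torch.multiprocessing as torch_mp

# Strategies available on this platform (on Linux: 'file_descriptor' and 'file_system').
print(torch_mp.get_all_sharing_strategies())

# Call this once at program start, before creating any DataLoader with workers.
torch_mp.set_sharing_strategy("file_system")
print(torch_mp.get_sharing_strategy())  # file_system
```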

import os
import numpy as np
import torch
# used at the beginning of your program
torch.multiprocessing.set_sharing_strategy('file_system')
import torchvision.transforms as transforms
from PIL import Image
from multiprocessing import Manager
from torch.utils.data import Dataset, DataLoader

class DatasetCache(object):
    def __init__(self, manager, use_cache=True):
        self.use_cache = use_cache
        self.manager = manager
        self._dict = manager.dict()  # proxy to a dict living in the manager process

    def is_cached(self, key):
        if not self.use_cache:
            return False
        return str(key) in self._dict

    def reset(self):
        self._dict.clear()

    def get(self, key):
        if not self.use_cache:
            raise AttributeError('Data caching is disabled and get function is unavailable! Check your config.')
        return self._dict[str(key)]

    def cache(self, key, img, lbl):
        # only store if full data in memory is enabled
        if not self.use_cache:
            return
        # only store if not already cached
        if str(key) in self._dict:
            return
        self._dict[str(key)] = (img, lbl)

class CachedListDataset(Dataset):
    '''Load image/labels from a list file.

    The list file is like:
      a.jpg label ...
    '''
    def __init__(self, root, list_file, cache, preprocess=None, transform=None):
        '''
        Args:
          root: (str) directory to images.
          list_file: (str/[str]) path to index file.
          cache: (DatasetCache) shared object cache.
          preprocess: (function) preprocessing function applied before caching.
          transform: (function) image transforms applied after the cache lookup.
        '''
        self.root = root
        self.cache = cache
        self.preprocess = preprocess
        self.transform = transform
        self.fnames = []
        self.labels = []

        if isinstance(list_file, list):
            # Cat multiple list files together.
            # This is especially useful for voc07/voc12 combination.
            tmp_file = '/tmp/listfile.txt'
            os.system('cat %s > %s' % (' '.join(list_file), tmp_file))
            list_file = tmp_file

        with open(list_file) as f:
            lines = f.readlines()
            self.num_imgs = len(lines)

        for line in lines:
            splited = line.strip().split()
            self.fnames.append(splited[0])
            # assumed format: integer label(s) follow the file name
            self.labels.append(np.array([int(x) for x in splited[1:]]))

    def reset_memory(self):
        self.cache.reset()

    def __len__(self):
        return self.num_imgs

    def __getitem__(self, idx):
        '''Load image.

        Args:
          idx: (int) image index.

        Returns:
          img: (tensor) image tensor.
          labels: (tensor) class label targets.
        '''
        if self.cache.is_cached(idx):
            img, labels = self.cache.get(idx)
        else:
            # Load image and labels.
            fname = self.fnames[idx]
            with Image.open(os.path.join(self.root, fname)) as i:
                # convert() returns a fully loaded copy, so img stays valid after the file is closed
                img = i.convert('RGB')
            if self.preprocess:
                img = self.preprocess(img)
            labels = torch.from_numpy(self.labels[idx])
            self.cache.cache(idx, img, labels)
        if self.transform:
            img = self.transform(img)
        return img, labels

# usage
manager = Manager()
cache = DatasetCache(manager)
ds = CachedListDataset(root='/path-to-data',
                       list_file='/path-to-list-file',  # placeholder path, like root
                       cache=cache,
                       preprocess=transforms.Resize((256, 256)),  # hypothetical size: resize once, before caching
                       transform=transforms.Compose([
                           transforms.RandomResizedCrop((224, 224), scale=[0.7, 1.0]),
                           transforms.ColorJitter(0.25, 0.25, 0.25),
                           transforms.ToTensor(),  # Normalize expects a tensor, not a PIL image
                           transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet statistics
                       ]))
dl = DataLoader(ds, ...)

Hi, thank you for this thread! Could you please let me know why creating the dataset in __init__ is not a good idea?