Best practice to cache the entire dataset during first epoch

Hi,

Currently, I am in a situation: the dataset is stored in a single file on a shared file system and too many processes accessing the file will cause a slow down to the file system (for example, 40 jobs each with 20 workers will end up 800 processes reading from the same file). So I plan to load the dataset to the memory.

I have enough memory (~500G) to hold the entire dataset (for example, ImageNet 1k), but loading the dataset before training is too slow. I would like to know if there is a good way to cache the entire dataset during the first epoch so that after first epoch workers will close the file and read directly from memory.

Thanks.

1 Like

One approach would be to store the already loaded data in the Dataset and use it afterwards.
The downside is that you are limited to num_workers=0 in the first epoch.

class MyDataset(Dataset):
    def __init__(self, use_cache=False):
        self.data = torch.randn(100, 1)
        self.cached_data = []
        self.use_cache = use_cache
        
    def __getitem__(self, index):
        if not self.use_cache:
            x = self.data[index] # your slow data loading
            self.cached_data.append(x)
        else:
            x = self.cached_data[index]
        
        return x
    
    def set_use_cache(self, use_cache):
        if use_cache:
            self.cached_data = torch.stack(self.cached_data)
        else:
            self.cached_data = []
        self.use_cache = use_cache
    
    def __len__(self):
        return len(self.data)


dataset = MyDataset(use_cache=False)
loader = DataLoader(
    dataset,
    num_workers=0,
    shuffle=False
)

for data in loader:
    print(len(loader.dataset.cached_data))

loader.dataset.set_use_cache(use_cache=True)
loader.num_workers = 2
for data in loader:
    print(len(loader.dataset.cached_data))
6 Likes

In order to use this for images-

I tried this solution for images to which I got this error

File “/home/sagarbhonde/kiwi/kd/dataset.py”, line 90, in set_use_cache
self.cached_images = torch.stack(self.cached_images)
TypeError: stack(): argument ‘tensors’ (position 1) must be tuple of Tensors, not Tensor

So you might want to change set_use_cache method to:

    def set_use_cache(self, use_cache):
        if use_cache:
            x_img = tuple(self.cached_images)
            self.cached_images = torch.stack(x_img)
        else:
            self.cached_images = []
        self.use_cache = use_cache
1 Like

@ptrblck I was just reading this response you wrote back in 2018. I am trying to do some basic speed up on my Dataloader (loading images). Is what you wrote here still valid? I saw you also responded to a similar post here and I wasn’t sure if you have to use multiprocessing Array in most cases. I liked the simplicity of what you show here. I am doing 5 folds, 15 epochs per fold, and Test Time Augmentation x 15 on Validation and Test. So that’s a lot of time reading the same files from disk over and over, and I am hoping I can use something like this to speed things up a bit.

The posted approach here might still be valid, but note that the first pass is using num_workers=0 in order to add the loaded samples into the Dataset.
The following epochs might use more workers, where each process will create a copy of the complete Dataset, which would increase the memory usage by num_workers.

The linked post uses the shared arrays to directly paste the loaded samples into it using multiple workers.

Let me know, if you encounter any issues.

@ptrblck I did encounter some issues, I’ll post in the other thread though as its relevant to that method.

Thank you for your answer, it works great.
Can you explain why num_workers must be zero, and what happens when it is not?

The first epoch would fill the “cache” in the original Dataset object using a single worker. The other epochs would then use multiple workers and reuse this cache, since each worker would create a copy of the dataset.
Note that this is not an optimal approach, but more a proof of concept.

I placed the cache object within the dataset class as a class variable, so copies would still “see and use” the same cached object. Although this works on the console, when using num_workers > 0 it doesn’t, and copying shouldn’t cause the effect ( as much as I understanding python’s internals, and I guess I don’t ).

in your reply, you wrote: "not an optimal approach… ", what do you recommend when handling images as files?

You could use shared arrays instead of this list, which forces you to use a the main process while the cache is being filled.

I tried to compile a version I used to cache my preprocessed images. This impl. saves some memory and gives some speedup since it avoids resizing the images in each process. About the torch.multiprocessing.set_sharing_strategy('file_system') call I am not absolutely sure and maybe @ptrblck has some comments on this, but it was suggested by the PyTorch framework and without it I got an exception after 14+ hours of training when passing on the same dataloader to 30+ different model training loops.

import torch
# used at the beginning of your program
torch.multiprocessing.set_sharing_strategy('file_system')
import torchvision.transforms as transforms
from PIL import Image
from multiprocessing import Manager
from torch.utils.data import Dataset, Dataloader


class DatasetCache(object):
    def __init__(self, manager, use_cache=True):
        self.use_cache = use_cache
        self.manager = manager
        self._dict = manager.dict()

    def is_cached(self, key):
        if not self.use_cache:
            return False
        return str(key) in self._dict

    def reset(self):
        self._dict.clear()

    def get(self, key):
        if not self.use_cache:
            raise AttributeError('Data caching is disabled and get funciton is unavailable! Check your config.')
        return self._dict[str(key)]

    def cache(self, key, img, lbl):
        # only store if full data in memory is enabled
        if not self.use_cache:
            return
        # only store if not already cached
        if str(key) in self._dict:
            return
        self._dict[str(key)] = (img, lbl)


class CachedListDataset(Dataset):
    '''Load image/labels from a list file.
    The list file is like:
      a.jpg label ...
    '''
    def __init__(self, root, list_file, cache, preprocess=None, transform=None):
        '''
        Args:
          root: (str) ditectory to images.
          list_file: (str/[str]) path to index file.
          cache: (DatasetCache) shared object cache
          preprocess: (function) preprocessing function before caching.
          transform: (function) image transforms.
        '''
        self.root = root
        self.cache = cache
        self.preprocess = preprocess
        self.transform = transform
        self.fnames = []
        self.labels = []

        if isinstance(list_file, list):
            # Cat multiple list files together.
            # This is especially useful for voc07/voc12 combination.
            tmp_file = '/tmp/listfile.txt'
            os.system('cat %s > %s' % (' '.join(list_file), tmp_file))
            list_file = tmp_file

        with open(list_file) as f:
            lines = f.readlines()
            self.num_imgs = len(lines)

        for line in lines:
            splited = line.strip().split()
            self.fnames.append(splited[0])
            self.labels.append(np.array([int(splited[1])]))

    def reset_memory(self):
        self.cache.reset()

    def __len__(self):
        return self.num_imgs

    def __getitem__(self, idx):
        '''Load image.
        Args:
          idx: (int) image index.
        Returns:
          img: (tensor) image tensor.
          labels: (tensor) class label targets.
        '''
        if self.cache.is_cached(idx):
            img, labels = self.cache.get(idx)
        else:
            # Load image and boxes.
            fname = self.fnames[idx]
            img = None
            with Image.open(os.path.join(self.root, fname)) as i:
                if i.mode != 'RGB':
                    i = i.convert('RGB')
                img = i
            if self.preprocess:
                img = self.preprocess(img)
            labels = self.labels[idx]
            labels = torch.from_numpy(labels)
            self.cache.cache(idx, img, labels)
        if self.transform:
            img = self.transform(img)
        return img, labels


# usage
manager = Manager()
cache = DatasetCache(manager)
ds = CachedListDataset(root='/path-to-data', 
                       list_file='/path-to-list-file', 
                       cache=cache,
                       preprocess=transforms.Compose([
                           transforms.Resize(256)
                       ]), 
                       transform=transforms.Compose([
                           transforms.RandomResizedCrop((224, 224), scale=[0.7, 1.0]),
                           transforms.RandomHorizontalFlip(),
                           transforms.ColorJitter(0.25, 0.25, 0.25),
                           transforms.RandomRotation(2),
                           transforms.ToTensor(),
                           transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # in case of imagenet statistics
                       ]))
dl = Dataloader(ds, ...)

Hi thank you for this thread! Could you please let me know why creating the dataset in init is not a good idea?

Not sure what you referring to, but init usually is expected to be fast and shouldn’t load too much data, as it can create bugs in case of multiprocessing. see How to share data among DataLoader processes to save memory