Best practice to cache the entire dataset during first epoch


#1

Hi,

Currently I am in the following situation: the dataset is stored in a single file on a shared file system, and having too many processes access that file slows the file system down (for example, 40 jobs each with 20 workers end up as 800 processes reading from the same file). So I plan to load the dataset into memory.

I have enough memory (~500 GB) to hold the entire dataset (for example, ImageNet-1k), but loading it all before training starts is too slow. Is there a good way to cache the entire dataset during the first epoch, so that after the first epoch the workers close the file and read directly from memory?

Thanks.


#2

One approach would be to store the already loaded samples inside the Dataset and reuse them afterwards.
The downside is that you are limited to num_workers=0 during the first epoch: each worker runs in its own process with its own copy of the dataset, so samples cached inside a worker would never reach the main process.

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, use_cache=False):
        self.data = torch.randn(100, 1)
        self.cached_data = []
        self.use_cache = use_cache

    def __getitem__(self, index):
        if not self.use_cache:
            x = self.data[index]  # your slow data loading
            self.cached_data.append(x)  # store the loaded sample in the cache
        else:
            x = self.cached_data[index]  # read directly from memory
        return x

    def set_use_cache(self, use_cache):
        if use_cache:
            # stack the cached samples into a single tensor for fast indexing
            self.cached_data = torch.stack(self.cached_data)
        else:
            self.cached_data = []
        self.use_cache = use_cache

    def __len__(self):
        return len(self.data)


# first epoch: load from the slow source and fill the cache (single process)
dataset = MyDataset(use_cache=False)
loader = DataLoader(
    dataset,
    num_workers=0,
    shuffle=False
)

for data in loader:
    print(len(loader.dataset.cached_data))

# following epochs: serve everything from the cache; workers can be used again
loader.dataset.set_use_cache(use_cache=True)
loader.num_workers = 2
for data in loader:
    print(len(loader.dataset.cached_data))
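
If you want to keep multiple workers during the first epoch as well, one variation is to preallocate the cache in shared memory, so that writes made inside worker processes are visible to the main process and to the other workers. This is only a minimal sketch under the assumption that every sample has a fixed, known shape and fits into a preallocated tensor; the class and attribute names here (SharedCacheDataset, cached_flags) are made up for illustration.

import torch
from torch.utils.data import Dataset, DataLoader


class SharedCacheDataset(Dataset):
    def __init__(self, num_samples=100, sample_shape=(1,)):
        # stands in for your slow data source
        self.data = torch.randn(num_samples, *sample_shape)
        # preallocated cache living in shared memory, visible across processes
        self.cached_data = torch.zeros(num_samples, *sample_shape).share_memory_()
        self.cached_flags = torch.zeros(num_samples, dtype=torch.bool).share_memory_()

    def __getitem__(self, index):
        if not self.cached_flags[index]:
            x = self.data[index]  # your slow data loading
            self.cached_data[index] = x      # write into the shared cache
            self.cached_flags[index] = True  # mark this sample as cached
        else:
            x = self.cached_data[index]      # served from memory from now on
        return x

    def __len__(self):
        return len(self.data)


dataset = SharedCacheDataset()
loader = DataLoader(dataset, num_workers=2, shuffle=True)

for epoch in range(2):
    for data in loader:
        pass
    # after the first epoch every sample should be cached
    print(dataset.cached_flags.sum().item())

Since the DataLoader samples each index exactly once per epoch, no two workers write to the same cache slot concurrently, so the cache is complete after the first pass. The trade-off is that the cache has to be preallocated up front, which only works if the sample shapes are known in advance.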