Best practice to cache the entire dataset during first epoch

One approach would be to store the already loaded data in the Dataset and use it afterwards.
The downside is that you are limited to num_workers=0 in the first epoch: each DataLoader worker process holds its own copy of the dataset, so samples cached inside a worker would never make it back to the main process.

class MyDataset(Dataset):
    """Toy dataset that caches each sample the first time it is loaded,
    so later epochs can skip the slow loading step entirely.

    Usage: iterate one full epoch with ``num_workers=0`` (so the cache
    fills in the main process), then call ``set_use_cache(True)``;
    subsequent epochs are served from the cache.
    """

    def __init__(self, use_cache=False):
        # Stand-in for a slow data source (disk reads, decoding, ...).
        self.data = torch.randn(100, 1)
        # Cache keyed by sample index. A dict (rather than an append-only
        # list) keeps the cache correct even if the sampler shuffles, an
        # index is read twice, or the warm-up epoch runs more than once.
        self.cached_data = {}
        self.use_cache = use_cache

    def __getitem__(self, index):
        if self.use_cache:
            return self.cached_data[index]
        x = self.data[index]  # your slow data loading
        self.cached_data[index] = x  # idempotent: re-reads don't duplicate
        return x

    def set_use_cache(self, use_cache):
        """Enable or disable serving from the cache.

        Enabling stacks the cached samples (in index order) into a single
        tensor; disabling clears the cache. Raises RuntimeError if the
        cache is enabled before every sample has been seen once.
        """
        if use_cache:
            if len(self.cached_data) != len(self.data):
                raise RuntimeError(
                    "cache is incomplete: iterate one full epoch before "
                    "calling set_use_cache(True)"
                )
            self.cached_data = torch.stack(
                [self.cached_data[i] for i in range(len(self.data))]
            )
        else:
            self.cached_data = {}
        self.use_cache = use_cache

    def __len__(self):
        return len(self.data)


# Warm-up epoch: single-process loading (num_workers=0) so the cache
# fills inside the main process -- each worker process would otherwise
# hold its own copy of the dataset, and its cache would be lost.
ds = MyDataset(use_cache=False)
loader = DataLoader(
    ds,
    num_workers=0,
    shuffle=False,
)

for batch in loader:
    print(len(loader.dataset.cached_data))

# Cache is now complete: freeze it and switch to multi-process loading.
loader.dataset.set_use_cache(use_cache=True)
loader.num_workers = 2
for batch in loader:
    print(len(loader.dataset.cached_data))
6 Likes