Dataloader resets dataset state

Yeah, I understand the issue and have stumbled over it myself a few times.

I think one possible approach would be to use shared memory in Python, e.g. with multiprocessing.Array.
You could initialize an array of the known size of your complete Dataset, fill it during the first epoch using all workers, and finally switch a flag indicating that the cache/shared memory should be used in all following epochs.
I’ve created a small dummy example showing this behavior.
Currently the shared memory is filled with torch.randn, so you can replace that line (marked with a comment in __getitem__) with your heavy loading function.

import torch
from torch.utils.data import Dataset, DataLoader

import ctypes
import multiprocessing as mp

import numpy as np


class MyDataset(Dataset):
    def __init__(self):
        # Allocate a single flat shared-memory buffer for the whole dataset and
        # view it as a [nb_samples, c, h, w] tensor without copying
        shared_array_base = mp.Array(ctypes.c_float, nb_samples * c * h * w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        shared_array = shared_array.reshape(nb_samples, c, h, w)
        self.shared_array = torch.from_numpy(shared_array)
        self.use_cache = False
        
    def set_use_cache(self, use_cache):
        self.use_cache = use_cache
    
    def __getitem__(self, index):
        if not self.use_cache:
            print('Filling cache for index {}'.format(index))
            # Add your heavy loading logic here; the result is written directly
            # into the shared buffer and reused in later epochs
            self.shared_array[index] = torch.randn(c, h, w)
        x = self.shared_array[index]
        return x
    
    def __len__(self):
        return nb_samples


nb_samples, c, h, w = 10, 3, 24, 24

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=2,
    shuffle=False
)

for epoch in range(2):
    for idx, data in enumerate(loader):
        print('Epoch {}, idx {}, data.shape {}'.format(epoch, idx, data.shape))
        
    # The cache is filled after the first epoch, so switch to reading from it
    if epoch == 0:
        loader.dataset.set_use_cache(True)
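
In case it helps, here is a minimal sketch of how the marked line in __getitem__ could look with real loading logic. It assumes (only for illustration) that your samples are image files, that PIL and torchvision are available, and that you would set up a hypothetical self.image_paths list of file paths in __init__:

from PIL import Image
import torchvision.transforms.functional as TF

    def __getitem__(self, index):
        if not self.use_cache:
            # Hypothetical heavy loading step replacing torch.randn
            img = Image.open(self.image_paths[index]).convert('RGB')
            img = TF.resize(img, (h, w))                  # match the spatial size of the cache
            self.shared_array[index] = TF.to_tensor(img)  # float32 tensor of shape [c, h, w]
        x = self.shared_array[index]
        return x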

Let me know if this would work for you.
