Yeah, I understand the issue and have stumbled over it myself a few times.
I think one possible approach would be to use shared memory in Python e.g. with multiprocessing.Array.
You could initialize an array of your known size for the complete Dataset, fill it in the first iteration using all workers, and finally switch a flag indicating the cache/shared memory should be used in all following epochs.
I’ve created a small dummy example showing this behavior.
Currently the shared memory is filled with torch.randn as a placeholder; replace that line with your heavy loading function.
import torch
from torch.utils.data import Dataset, DataLoader
import ctypes
import multiprocessing as mp
import numpy as np
class MyDataset(Dataset):
    """Dataset that caches samples in a multiprocessing shared-memory buffer.

    While ``use_cache`` is False, each ``__getitem__`` call fills its slot in
    the shared array (here with torch.randn as a stand-in for real loading).
    Because the buffer is backed by ``multiprocessing.Array``, writes made in
    DataLoader worker processes are visible to the parent process, so after
    the first epoch the cache is complete; call ``set_use_cache(True)`` to
    serve subsequent epochs straight from the cache.
    """

    def __init__(self, nb_samples=10, c=3, h=24, w=24):
        """Allocate the shared cache.

        Args:
            nb_samples: number of samples held by the dataset.
            c: channels per sample.
            h: height per sample.
            w: width per sample.
        """
        # Keep the dimensions on the instance instead of relying on module
        # globals, so the class is reusable outside this script.
        self.nb_samples = nb_samples
        self.c, self.h, self.w = c, h, w
        # get_obj() exposes the underlying ctypes array so numpy can wrap the
        # shared buffer without copying.
        shared_array_base = mp.Array(ctypes.c_float, nb_samples * c * h * w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        shared_array = shared_array.reshape(nb_samples, c, h, w)
        # torch.from_numpy shares memory with the numpy array, so this tensor
        # is a zero-copy view of the shared buffer.
        self.shared_array = torch.from_numpy(shared_array)
        self.use_cache = False

    def set_use_cache(self, use_cache):
        """Enable (True) or disable (False) reading from the filled cache."""
        self.use_cache = use_cache

    def __getitem__(self, index):
        if not self.use_cache:
            print('Filling cache for index {}'.format(index))
            # Add your loading logic here
            self.shared_array[index] = torch.randn(self.c, self.h, self.w)
        x = self.shared_array[index]
        return x

    def __len__(self):
        return self.nb_samples
# Sample count and per-sample dimensions used by the dummy dataset.
nb_samples, c, h, w = 10, 3, 24, 24

dataset = MyDataset()
# Two workers demonstrate that fills done in worker processes land in the
# shared buffer visible to the parent.
loader = DataLoader(
    dataset,
    num_workers=2,
    shuffle=False
)

for epoch in range(2):
    for batch_idx, batch in enumerate(loader):
        print('Epoch {}, idx {}, data.shape {}'.format(epoch, batch_idx, batch.shape))
    # The first pass filled every slot of the shared cache; serve all
    # following epochs from it.
    if epoch == 0:
        loader.dataset.set_use_cache(True)
Let me know, if this would work for you.