Yeah, I understand the issue and have stumbled over it myself a few times.
I think one possible approach would be to use shared memory in Python, e.g. with multiprocessing.Array.
You could initialize an array of the known size of your complete Dataset, fill it during the first epoch using all workers, and finally switch a flag indicating that the cache/shared memory should be used in all following epochs.
I've created a small dummy example showing this behavior.
Currently the shared memory is filled with torch.randn, so that's the line of code where you can add your heavy loading function.
import torch
from torch.utils.data import Dataset, DataLoader

import ctypes
import multiprocessing as mp

import numpy as np


class MyDataset(Dataset):
    def __init__(self):
        # Allocate one flat shared-memory buffer for the whole dataset and
        # expose it as a tensor of shape [nb_samples, c, h, w]
        shared_array_base = mp.Array(ctypes.c_float, nb_samples * c * h * w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        shared_array = shared_array.reshape(nb_samples, c, h, w)
        self.shared_array = torch.from_numpy(shared_array)
        self.use_cache = False

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def __getitem__(self, index):
        if not self.use_cache:
            print('Filling cache for index {}'.format(index))
            # Add your loading logic here
            self.shared_array[index] = torch.randn(c, h, w)
        x = self.shared_array[index]
        return x

    def __len__(self):
        return nb_samples


nb_samples, c, h, w = 10, 3, 24, 24

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=2,
    shuffle=False
)

for epoch in range(2):
    for idx, data in enumerate(loader):
        print('Epoch {}, idx {}, data.shape {}'.format(epoch, idx, data.shape))

    # After the first epoch the shared array is filled, so use the cache from now on
    if epoch == 0:
        loader.dataset.set_use_cache(True)
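
If you want to check that the caching actually pays off, here is a minimal sketch on top of the code above; the time.sleep and the 0.5s value are just stand-ins for your real, expensive loading function, and everything else reuses MyDataset as defined before. The second epoch should finish much faster than the first:

import time

# Simulate an expensive load by sleeping; the shared array and flag logic
# are inherited unchanged from MyDataset above.
class SlowDataset(MyDataset):
    def __getitem__(self, index):
        if not self.use_cache:
            time.sleep(0.5)  # stand-in for a slow read from disk
            self.shared_array[index] = torch.randn(c, h, w)
        return self.shared_array[index]

slow_loader = DataLoader(SlowDataset(), num_workers=2, shuffle=False)
for epoch in range(2):
    t0 = time.time()
    for data in slow_loader:
        pass
    print('Epoch {} took {:.2f}s'.format(epoch, time.time() - t0))
    if epoch == 0:
        slow_loader.dataset.set_use_cache(True)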
Let me know if this works for you.