One approach would be to store the already loaded data in the Dataset and reuse it in the following epochs.
The downside is that you are limited to num_workers=0 in the first epoch: with worker processes, each worker
fills the cache in its own copy of the Dataset, so the loaded samples never reach the Dataset object in the
main process (see the sketch after the example below).
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, use_cache=False):
        self.data = torch.randn(100, 1)
        self.cached_data = []
        self.use_cache = use_cache

    def __getitem__(self, index):
        if not self.use_cache:
            x = self.data[index]  # your slow data loading
            self.cached_data.append(x)
        else:
            x = self.cached_data[index]
        return x

    def set_use_cache(self, use_cache):
        if use_cache:
            # stack the list of samples into a single tensor once the cache is complete
            self.cached_data = torch.stack(self.cached_data)
        else:
            self.cached_data = []
        self.use_cache = use_cache

    def __len__(self):
        return len(self.data)
dataset = MyDataset(use_cache=False)
loader = DataLoader(
    dataset,
    num_workers=0,  # first epoch: load in the main process so the cache is filled here
    shuffle=False   # keep the indices in order so cached_data[index] matches the sample
)

# first epoch: slow loading, cache grows with every sample
for data in loader:
    print(len(loader.dataset.cached_data))

# switch to the cache and use multiple workers from now on
loader.dataset.set_use_cache(use_cache=True)
loader.num_workers = 2
for data in loader:
    print(len(loader.dataset.cached_data))
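
To see why the first pass really needs num_workers=0, here is a minimal sketch (assuming the MyDataset
class from above is in scope): each worker process appends to the cache of its own copy of the dataset,
so nothing arrives back in the parent process.

if __name__ == "__main__":
    ds = MyDataset(use_cache=False)

    # Worker processes operate on their own copies of the dataset, so the
    # appends to cached_data never reach this process.
    for _ in DataLoader(ds, num_workers=2, shuffle=False):
        pass
    print(len(ds.cached_data))  # 0 -> cache is still empty here

    # With num_workers=0 the loading runs in this process and the cache is filled.
    for _ in DataLoader(ds, num_workers=0, shuffle=False):
        pass
    print(len(ds.cached_data))  # 100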