Just define a Dataset object that only collects the list of file names in __init__
and loads the corresponding file each time __getitem__ is called.
Then wrap it in a torch.utils.data.DataLoader with multiple workers,
and you'll have your files loaded lazily and in parallel.
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads one file per index, on demand.

    Only the file names are collected at construction time; the file
    contents are read in ``__getitem__``, so nothing is loaded until an
    item is actually requested.
    """

    def __init__(self):
        """Collect the names of the entries in ``'data_dir'``, sorted."""
        super().__init__()
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping deterministic across runs.
        self.data_files = sorted(os.listdir('data_dir'))

    def __getitem__(self, idx):
        """Load and return the file at position ``idx``.

        Must be named ``__getitem__``: that is the method Python's indexing
        protocol (``dset[idx]``) dispatches to.
        """
        # NOTE(review): os.listdir yields bare entry names without the
        # 'data_dir' prefix -- confirm load_file resolves them against
        # that directory.
        return load_file(self.data_files[idx])

    def __len__(self):
        """Return the number of files found in ``__init__``."""
        return len(self.data_files)
# Build the lazy dataset and wrap it in a DataLoader; num_workers=8 asks the
# loader to fetch items through 8 worker processes, so files load in parallel.
dset = MyDataset()
# DataLoader lives in torch.utils.data, not directly under torch.utils.
loader = torch.utils.data.DataLoader(dset, num_workers=8)