I have several “pools” of data. For my given training task, I’d like to load data from these pools, but need to ensure that in a given batch only data from the same “pool” is returned.
The way I currently have this implemented is to create a separate dataloader for each pool and create a master dataloader that wraps the individual dataloaders. I’ve noticed this is pretty slow though, as it loses out on a lot of the optimization the pytorch dataloader does. Are there any suggestions to optimize this behavior?
For reference, here is the code I have for the “MasterDataLoader”:
class MasterLoader:
    """Round-robins over per-pool DataLoaders so every batch contains data
    from exactly one pool (one file == one pool).

    At most ``max_active_loaders`` DataLoaders are kept open at once; when one
    is exhausted it is replaced by a loader over the next queued file, until
    all files have been consumed (one epoch).
    """

    def __init__(self, data_files, max_active_loaders=10, **data_loader_args):
        """
        Args:
            data_files: paths to torch-saved pool files. Files that fail to
                load (truncated/corrupt -> EOFError) are dropped.
            max_active_loaders: maximum number of simultaneously open loaders.
            **data_loader_args: forwarded to every DataLoader; must include
                'batch_size' so the epoch length can be computed.
        """
        self.current_loader = 0
        self.max_active_loaders = max_active_loaders
        self.active_loaders = []
        self.data_loader_args = data_loader_args
        batch_size = data_loader_args['batch_size']
        self.len = 0
        # Verify files into a fresh list instead of calling list.remove()
        # while iterating: removing during iteration skips the element that
        # follows each bad file, leaving it unverified and uncounted.
        valid_files = []
        for file in tqdm(data_files, desc='Loading and verifying files'):
            try:
                # NOTE: loads the whole file just to count samples; this is
                # the dominant startup cost.
                self.len += np.ceil(len(torch.load(file)) / batch_size)
                valid_files.append(file)
            except EOFError:
                continue  # truncated/corrupt file: drop it
        self.data_files = valid_files
        self.len = int(self.len)
        print(f'There are a total of {self.len} batches in an epoch.')

    def __len__(self):
        """Total number of batches in one epoch."""
        return self.len

    def __iter__(self):
        """Start a new epoch: reset the file queue and open the initial loaders."""
        self.data_files_queue = self.data_files[:]  # reset queue
        self.active_loaders = [iter(DataLoader(AIDDataset(self.data_files_queue.pop()), **self.data_loader_args))
                               for _ in range(min(len(self.data_files_queue), self.max_active_loaders))]
        return self

    def __next__(self):
        """Return the next batch, drawn from a single pool.

        Raises:
            StopIteration: when every loader and the file queue are exhausted.
        """
        while True:  # iterate until a valid batch is produced
            if not self.active_loaders and not self.data_files_queue:
                raise StopIteration  # finished an epoch
            try:
                return next(self.active_loaders[self.current_loader])
            except StopIteration:  # this loader ran out of samples
                if self.data_files_queue:
                    # Replace the exhausted loader with one over the next file.
                    self.active_loaders[self.current_loader] = iter(
                        DataLoader(AIDDataset(self.data_files_queue.pop()), **self.data_loader_args))
                else:
                    self.active_loaders.pop(self.current_loader)
                # Advance only when a loader survives; the original computed
                # `current_loader % 0` here after popping the last loader,
                # crashing with ZeroDivisionError instead of ending the epoch.
                if self.active_loaders:
                    self.current_loader = (self.current_loader + 1) % len(self.active_loaders)
                else:
                    self.current_loader = 0