Dataloading with Homogeneous Batches from Pools of Data

I have several “pools” of data. For my training task, I’d like to load data from these pools, but I need to ensure that every batch contains data from only one “pool”.

My current implementation creates a separate DataLoader for each pool, plus a master loader that wraps the individual DataLoaders. I’ve noticed this is quite slow, though, because it loses many of the optimizations the PyTorch DataLoader provides. Are there any suggestions for optimizing this?

For reference, here is the code for my “MasterLoader”:

class MasterLoader:
    def __init__(self, data_files, max_active_loaders=10, **data_loader_args):
        self.current_loader = 0
        self.max_active_loaders = max_active_loaders
        self.active_loaders = []
        self.data_files = data_files
        self.data_loader_args = data_loader_args

        batch_size = data_loader_args['batch_size']
        self.len = 0

        for file in tqdm(self.data_files, desc='Loading and verifying files'):
                self.len += np.ceil(len(torch.load(file)) / batch_size)
            except EOFError:
        self.len = int(self.len)
        print(f'There are a total of {self.len} batches in an epoch.')

    def __len__(self):
        return self.len

    def __iter__(self):
        self.data_files_queue = self.data_files[:]  # reset queue
        self.active_loaders = [iter(DataLoader(AIDDataset(self.data_files_queue.pop()), **self.data_loader_args))
                               for _ in range(min(len(self.data_files_queue), self.max_active_loaders))]
        return self

    def __next__(self):
        while True:  # iterate until valid sample is returned
            if len(self.active_loaders) == 0 and len(self.data_files_queue) == 0:  # finished an epoch
                raise StopIteration

                sample = next(self.active_loaders[self.current_loader])
            except StopIteration:  # if the dataloader ran out of samples
                if len(self.data_files_queue) > 0:
                    self.active_loaders[self.current_loader] = iter(DataLoader(AIDDataset(self.data_files_queue.pop()),
                    self.active_loaders.pop(self.current_loader)  # remove loader at current index

        self.current_loader += 1
        self.current_loader = self.current_loader % len(self.active_loaders)

        return sample