Combining DataLoaders with variable batch size

Hi there,

I’m currently working on a project with both labeled and unlabeled data, and my DataLoader needs to serve both. It’s important that this be a DataLoader and not a Dataset, because the proportion of labeled vs. unlabeled samples within each batch is an important hyperparameter I want to control.

So far my solution is to iterate over two dataloaders whose batch sizes are derived from a global batch size and an intra-batch proportion parameter. I’ve subclassed the _BaseDataLoaderIter class, but I feel like this must introduce some leakage or problems with multiprocessing:

import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from torch.utils.data.dataloader import _BaseDataLoaderIter


class CustomDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, supervised_dataloader, unsupervised_dataloader, *args, **kwargs):
        super().__init__(unsupervised_dataloader)
        self.supervised_loader = supervised_dataloader
        self.unsupervised_loader = unsupervised_dataloader
        self.supervised_iter = iter(supervised_dataloader)
        self.unsupervised_iter = iter(unsupervised_dataloader)

        # Infinitely cycle the smaller dataset; the epoch ends once the
        # longer of the two loaders has been consumed (matches __len__ below).
        self._max_batches = max(len(supervised_dataloader), len(unsupervised_dataloader))
        self._batches_yielded = 0

    def __next__(self):
        # Stop after one pass over the longer loader so the epoch terminates.
        if self._batches_yielded >= self._max_batches:
            raise StopIteration
        self._batches_yielded += 1

        try:
            supervised_batch = next(self.supervised_iter)
        except StopIteration:
            # Restart the exhausted loader with a fresh iterator rather than
            # calling the private _reset() method.
            self.supervised_iter = iter(self.supervised_loader)
            supervised_batch = next(self.supervised_iter)

        try:
            unsupervised_batch = next(self.unsupervised_iter)
        except StopIteration:
            self.unsupervised_iter = iter(self.unsupervised_loader)
            unsupervised_batch = next(self.unsupervised_iter)

        # Collate the two dictionaries into a single one by concatenating
        # along the batch dimension (assumes both loaders yield the same keys).
        batch = {}
        for key in supervised_batch:
            batch[key] = torch.cat([supervised_batch[key], unsupervised_batch[key]])

        return batch

class MixedDataLoader(DataLoader):

    def __init__(self, supervised_dataset, unsupervised_dataset,
                 supervised_dataset_percentage=1.0, in_batch_supervised_percentage=0.5,
                 batch_size=32, *args, **kwargs):
        self.supervised_dataset = supervised_dataset
        self.unsupervised_dataset = unsupervised_dataset
        self.supervised_dataset_percentage = supervised_dataset_percentage
        self.in_batch_supervised_percentage = in_batch_supervised_percentage
        self.batch_size = batch_size

        # Shuffle and randomly keep supervised_dataset_percentage of the supervised dataset.
        self.supervised_dataset_size = len(self.supervised_dataset)
        n_keep = int(self.supervised_dataset_size * self.supervised_dataset_percentage)
        self.supervised_dataset_indices = np.random.permutation(self.supervised_dataset_size)[:n_keep]
        self.supervised_dataset = Subset(self.supervised_dataset, self.supervised_dataset_indices)

        # Split the global batch size according to the intra-batch proportion.
        self.supervised_batch_size = int(self.batch_size * self.in_batch_supervised_percentage)
        self.unsupervised_batch_size = self.batch_size - self.supervised_batch_size

        self.supervised_dataloader = DataLoader(
            self.supervised_dataset, batch_size=self.supervised_batch_size, shuffle=True)
        self.unsupervised_dataloader = DataLoader(
            self.unsupervised_dataset, batch_size=self.unsupervised_batch_size, shuffle=True)

    def __iter__(self) -> _BaseDataLoaderIter:
        return CustomDataLoaderIter(self.supervised_dataloader, self.unsupervised_dataloader)

    def __len__(self):
        # One epoch corresponds to one pass over the longer of the two loaders.
        return max(len(self.supervised_dataloader), len(self.unsupervised_dataloader))
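
For reference, this is roughly how I use it at the moment (labeled_ds and unlabeled_ds are placeholders for my actual datasets, which both return dicts with the same keys, e.g. "image"):

mixed_loader = MixedDataLoader(
    supervised_dataset=labeled_ds,
    unsupervised_dataset=unlabeled_ds,
    supervised_dataset_percentage=1.0,   # keep all labeled samples
    in_batch_supervised_percentage=0.5,  # half of each batch is labeled
    batch_size=32,
)

batch = next(iter(mixed_loader))
# batch["image"] has 32 samples along dim 0: the first 16 come from the
# labeled dataset and the last 16 from the unlabeled one, due to the
# torch.cat order in CustomDataLoaderIter.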
        
        

Any pointers towards the least disruptive way of doing this would be appreciated. I use PyTorch Lightning and rely on its automatic multiprocessing / multi-GPU handling, so I would love to avoid having to handle that manually just for this.
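
For completeness, here is a rough sketch of how I currently plug it into Lightning; SemiSupervisedModel, labeled_ds and unlabeled_ds are placeholder names, and the rest of the module is omitted:

import pytorch_lightning as pl

class SemiSupervisedModel(pl.LightningModule):
    # training_step, configure_optimizers, etc. omitted for brevity

    def train_dataloader(self):
        # The Trainer consumes the mixed loader like any other DataLoader.
        return MixedDataLoader(
            supervised_dataset=labeled_ds,
            unsupervised_dataset=unlabeled_ds,
            in_batch_supervised_percentage=0.5,
            batch_size=32,
        )

trainer = pl.Trainer()
trainer.fit(SemiSupervisedModel())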

Thanks!