Hi there,
I’m currently working on a project that has both labeled and unlabeled data, and my DataLoader needs to handle both. It’s important that this happen at the DataLoader level rather than the Dataset level, because the intra-batch proportion of labeled vs. unlabeled samples is an important hyperparameter I would like to control.
So far my solution is to iterate over two dataloaders whose batch sizes are derived from a global batch size and an intra-batch proportion parameter (e.g. with a global batch size of 32 and a supervised proportion of 0.5, each loader contributes 16 samples per step). I’ve subclassed _BaseDataLoaderIter, but I feel like this must cause some state leakage or problems with multiprocessing:
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from torch.utils.data.dataloader import _BaseDataLoaderIter

class CustomDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, supervised_dataloader, unsupervised_dataloader, *args, **kwargs):
        super().__init__(loader=unsupervised_dataloader)
        self.supervised_loader = supervised_dataloader
        self.unsupervised_loader = unsupervised_dataloader
        self.supervised_iter = iter(supervised_dataloader)
        self.unsupervised_iter = iter(unsupervised_dataloader)

    # infinitely cycle whichever loader is exhausted first
    def __next__(self):
        try:
            supervised_batch = next(self.supervised_iter)
        except StopIteration:
            self.supervised_iter._reset(self.supervised_loader)
            supervised_batch = next(self.supervised_iter)
        try:
            unsupervised_batch = next(self.unsupervised_iter)
        except StopIteration:
            self.unsupervised_iter._reset(self.unsupervised_loader)
            unsupervised_batch = next(self.unsupervised_iter)
        # merge the two batch dicts by concatenating along the batch dimension
        batch = {}
        for key in supervised_batch:
            batch[key] = torch.cat([supervised_batch[key], unsupervised_batch[key]])
        return batch
class MixedDataLoader(DataLoader):
    def __init__(self, supervised_dataset, unsupervised_dataset, supervised_dataset_percentage=1.0, in_batch_supervised_percentage=0.5, batch_size=32, *args, **kwargs):
        self.supervised_dataset = supervised_dataset
        self.unsupervised_dataset = unsupervised_dataset
        self.supervised_dataset_percentage = supervised_dataset_percentage
        self.in_batch_supervised_percentage = in_batch_supervised_percentage
        self.batch_size = batch_size
        self.supervised_dataset_size = len(self.supervised_dataset)
        # shuffle and randomly keep supervised_dataset_percentage of the supervised dataset
        self.supervised_dataset_indices = np.random.permutation(self.supervised_dataset_size)[:int(self.supervised_dataset_size * self.supervised_dataset_percentage)]
        self.supervised_dataset = Subset(self.supervised_dataset, self.supervised_dataset_indices)
        # split the global batch size between the two loaders
        self.supervised_batch_size = int(self.batch_size * self.in_batch_supervised_percentage)
        self.unsupervised_batch_size = self.batch_size - self.supervised_batch_size
        self.supervised_dataloader = DataLoader(self.supervised_dataset, batch_size=self.supervised_batch_size, shuffle=True)
        self.unsupervised_dataloader = DataLoader(self.unsupervised_dataset, batch_size=self.unsupervised_batch_size, shuffle=True)

    def __iter__(self) -> _BaseDataLoaderIter:
        return CustomDataLoaderIter(self.supervised_dataloader, self.unsupervised_dataloader)

    def __len__(self):
        return max(len(self.supervised_dataloader), len(self.unsupervised_dataloader))
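For context, this is roughly how I construct and consume it (the dataset variables below are placeholders for my actual datasets, which both return dicts with the same keys):

mixed_loader = MixedDataLoader(
    supervised_dataset=my_labeled_dataset,      # placeholder: any map-style dataset returning dicts
    unsupervised_dataset=my_unlabeled_dataset,  # placeholder: same keys as the labeled dataset
    supervised_dataset_percentage=1.0,
    in_batch_supervised_percentage=0.5,
    batch_size=32,
)
for batch in mixed_loader:
    # with batch_size=32 and a 0.5 proportion, each tensor holds 16 labeled
    # samples followed by 16 unlabeled samples along dim 0
    ...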
Any pointers toward the least breaking way to do this would be appreciated. I use PyTorch Lightning and rely on its automatic multiprocessing / multi-GPU handling, so I would love not to have to fall back to a manual loop just for this.
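In case it’s relevant, this is roughly how the loader is handed to Lightning at the moment (a minimal sketch; the module name and internals are placeholders and the training logic is omitted):

import pytorch_lightning as pl

class SemiSupervisedModule(pl.LightningModule):
    def __init__(self, labeled_dataset, unlabeled_dataset):
        super().__init__()
        self.labeled_dataset = labeled_dataset
        self.unlabeled_dataset = unlabeled_dataset

    def train_dataloader(self):
        # Lightning calls this hook to build the training loader
        return MixedDataLoader(self.labeled_dataset, self.unlabeled_dataset, batch_size=32)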
Thanks!