I want to implement a simple form of multi-task learning. Let us say there are two tasks, A and B. I want to create a dataloader such that the batches alternate between these tasks, i.e., each batch contains samples from a single task only. The first approach I am trying is to create a dataloader for each task in the usual way and then combine them with a `MultitaskDataLoader`. A POC implementation is as follows:
```python
import random

import torch


class MultitaskDataLoader(torch.utils.data.DataLoader):
    def __init__(self, task_names, datasets):
        self.task_names = task_names
        self.lengths = [len(d) for d in datasets]
        # One iterator per task; each __next__ pulls one batch from one of them.
        self.iterators = [iter(d) for d in datasets]
        # One entry per batch, e.g. [0, 0, ..., 1, 1, ...]; shuffling this
        # decides which task each batch of the epoch comes from.
        indices = [[i] * v for i, v in enumerate(self.lengths)]
        self.task_indices = sum(indices, [])

    def _reset(self):
        random.shuffle(self.task_indices)
        self.current_index = 0

    def __iter__(self):
        self._reset()
        return self

    def __len__(self):
        return sum(self.lengths)

    def __next__(self):
        if self.current_index < len(self.task_indices):
            task_index = self.task_indices[self.current_index]
            task_name = self.task_names[task_index]
            batch = next(self.iterators[task_index])
            new_batch = (batch, task_name)
            self.current_index += 1
            return new_batch
        else:
            raise StopIteration


task_names = ["A", "B"]
# Stand-ins for the per-task dataloaders; each element plays the role of a batch.
d1 = ["task-A"] * 5
d2 = ["task-B"] * 10
dl = MultitaskDataLoader(task_names, [d1, d2])
```
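Iterating over it once behaves the way I want: each element is a `(batch, task_name)` tuple, and the task order varies because of the shuffle:

```python
for batch, task_name in dl:
    print(task_name, batch)
# prints 15 lines in total: 5 batches tagged "A" and 10 tagged "B",
# interleaved in a random order
```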
This works as expected, but it stops working after every epoch: I have to create a new `MultitaskDataLoader` object whenever a new epoch starts. We do not have to do that for the standard `torch.utils.data.DataLoader`, so why do I have to do it here? What should I change to make this one work exactly like the standard one? A minimal reproduction of the behaviour is below.
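Continuing the POC above, this sketch shows both the failure and my current workaround. My guess is that the per-task iterators created in `__init__` are exhausted after the first pass, but I am not sure that is the whole story:

```python
# Reusing the same loader: the second epoch yields nothing.
dl = MultitaskDataLoader(task_names, [d1, d2])
for epoch in range(2):
    n = sum(1 for _ in dl)
    print(f"epoch {epoch}: {n} batches")
# epoch 0: 15 batches
# epoch 1: 0 batches   <- stops immediately

# Current workaround: rebuild the loader at the start of every epoch.
for epoch in range(2):
    dl = MultitaskDataLoader(task_names, [d1, d2])
    n = sum(1 for _ in dl)
    print(f"epoch {epoch}: {n} batches")  # 15 batches in each epoch
```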