I want to implement a simple form of multi-task learning. Say there are two tasks, A and B. I want to create a dataloader whose batches alternate between these tasks, i.e. each batch contains samples from a single task only. My first approach is to create a dataloader for each task in the usual way and then combine them with a MultitaskDataLoader. A proof-of-concept implementation is as follows:
import random

import torch


class MultitaskDataLoader(torch.utils.data.DataLoader):
    def __init__(self, task_names, datasets):
        self.task_names = task_names
        self.lengths = [len(d) for d in datasets]
        self.iterators = [iter(d) for d in datasets]
        # One entry per batch, recording which task it should come from,
        # e.g. [0, 0, ..., 1, 1, ...] for two tasks.
        indices = [[i] * v for i, v in enumerate(self.lengths)]
        self.task_indices = sum(indices, [])

    def _reset(self):
        # Shuffle the task schedule and restart the position counter.
        random.shuffle(self.task_indices)
        self.current_index = 0

    def __iter__(self):
        self._reset()
        return self

    def __len__(self):
        return sum(self.lengths)

    def __next__(self):
        if self.current_index < len(self.task_indices):
            task_index = self.task_indices[self.current_index]
            task_name = self.task_names[task_index]
            batch = next(self.iterators[task_index])
            new_batch = (batch, task_name)
            self.current_index += 1
            return new_batch
        else:
            raise StopIteration
task_names = ["A", "B"]
d1 = ['task-A'] * 5
d2 = ['task-B'] * 10
dl = MultitaskDataLoader(task_names, [d1, d2])
This works as expected, but it stops after one epoch: I have to create a new dataloader object at the start of every epoch. We don't have to do that for the standard torch.utils.data.DataLoader, so why do I have to do it here? What should I change to make this one behave exactly like the standard one?
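I suspect it has something to do with the fact that I create the iterators once in __init__ and never refresh them, so they stay exhausted after the first epoch. A plain Python iterator over a list behaves the same way (minimal standalone illustration, unrelated to torch):

```python
data = [1, 2, 3]
it = iter(data)

print(list(it))  # [1, 2, 3] -- first pass consumes the iterator
print(list(it))  # []        -- second pass yields nothing

# By contrast, iterating over the list itself calls iter() afresh
# each time, which is why `for x in data` works every epoch.
print(list(iter(data)))  # [1, 2, 3]
```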