Combine multiple DataLoaders for multi-task learning

I want to implement a simple form of multi-task learning. Let us say there are two tasks, A and B. I want to create a dataloader whose batches alternate between these tasks, i.e., each batch contains samples from only a single task. My first approach is to create a DataLoader for each task in the usual way and then combine them with a MultitaskDataLoader. A proof-of-concept implementation is as follows:

class MultitaskDataLoader:

    def __init__(self, task_names, datasets):
        self.task_names = task_names
        self.lengths = [len(d) for d in datasets]
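        # One iterator per task, created once when the object is built
        self.iterators = [iter(d) for d in datasets]
        # Task index of every batch, e.g. [0, 0, 0, 0, 0, 1, 1, ...]: all A batches, then all B
        indices = [[i] * v for i, v in enumerate(self.lengths)]
        self.task_indices = sum(indices, [])
        # current_index must exist before the first call to __next__
        self._reset()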

    def _reset(self):
        self.current_index = 0

    def __iter__(self):
        return self

    def __len__(self):
        return sum(self.lengths)

    def __next__(self):
        if self.current_index < len(self.task_indices):
            task_index = self.task_indices[self.current_index]
            task_name = self.task_names[task_index]
            batch = next(self.iterators[task_index])
            new_batch = (batch, task_name)
            self.current_index += 1
            return new_batch
        raise StopIteration

task_names = ["A", "B"]
d1 = ['task-A'] * 5
d2 = ['task-B'] * 10
dl = MultitaskDataLoader(task_names, [d1, d2])
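In this snippet the task_indices list decides which task each batch is drawn from, and the sum([[i] * v ...], []) construction puts every A batch before every B batch rather than alternating them. The sketch below compares that order with a round-robin order; round_robin_task_order is a hypothetical helper, not part of the original code, and assumes the tasks should simply take turns until the shorter one runs out.

```python
from itertools import chain, zip_longest


def sequential_task_order(lengths):
    # Same construction as in __init__: all batches of task 0, then task 1, ...
    return sum([[i] * n for i, n in enumerate(lengths)], [])


def round_robin_task_order(lengths):
    # Hypothetical alternative: take one batch from each task in turn;
    # once a task runs out, zip_longest pads it with None, which is filtered out.
    per_task = [[i] * n for i, n in enumerate(lengths)]
    interleaved = chain.from_iterable(zip_longest(*per_task, fillvalue=None))
    return [i for i in interleaved if i is not None]


print(sequential_task_order([5, 10]))   # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
print(round_robin_task_order([5, 10]))  # [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1]
```

Assigning either list to self.task_indices leaves the rest of the class unchanged; shuffling the sequential list with random.shuffle at the start of each epoch would be another way to mix the tasks.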

This works as expected for one epoch, but then it stops: a second loop over dl yields nothing, so I have to create a new MultitaskDataLoader object at the start of every epoch. We do not have to do that for the standard DataLoader, so why do I have to do it for this one? What should I change to make it behave exactly like the standard one?
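The difference comes from how a for loop works: it calls iter() on the object once at the start of every loop. PyTorch's DataLoader is an iterable whose __iter__ returns an iterator that starts a new epoch each time, whereas the class above returns self, so the second loop reuses the already-exhausted per-task iterators and the already-advanced current_index. The pure-Python sketch below (ReturnsSelf, FreshEachTime and run_epochs are illustrative names, not from the original) reproduces the two behaviors without PyTorch.

```python
class ReturnsSelf:
    """Iterator-style object: __iter__ returns self, like MultitaskDataLoader above."""

    def __init__(self, data):
        self._it = iter(data)  # created once and shared by every loop

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._it)  # raises StopIteration once the data is used up


class FreshEachTime:
    """Iterable-style object: __iter__ builds a new iterator for every loop."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)


def run_epochs(obj, epochs=2):
    # Each list comprehension is a for loop, so it calls iter(obj) once per epoch
    return [len([batch for batch in obj]) for _ in range(epochs)]


print(run_epochs(ReturnsSelf(["a", "b", "c"])))    # [3, 0]
print(run_epochs(FreshEachTime(["a", "b", "c"])))  # [3, 3]
```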

I found my mistake: the per-task iterators should be created in `__iter__` (and current_index reset there too), because that is what a for loop calls every time it starts. I will leave this question up to help others who might have a similar problem.
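A minimal sketch of the fix, assuming each per-task loader supports len() and iter() the way a PyTorch DataLoader over a map-style dataset does (plain lists stand in for the loaders, as in the question). Writing __iter__ as a generator means every for loop gets its own fresh per-task iterators and its own position, so current_index and _reset are no longer needed.

```python
class MultitaskDataLoader:
    """Yields (batch, task_name) pairs and can be looped over once per epoch."""

    def __init__(self, task_names, dataloaders):
        self.task_names = task_names
        self.dataloaders = dataloaders
        self.lengths = [len(d) for d in dataloaders]
        # Task index of every batch: all batches of task 0, then task 1, ...
        self.task_indices = sum([[i] * n for i, n in enumerate(self.lengths)], [])

    def __len__(self):
        return sum(self.lengths)

    def __iter__(self):
        # Called at the start of every for loop: build fresh per-task
        # iterators for this epoch instead of reusing ones from __init__
        iterators = [iter(d) for d in self.dataloaders]
        for task_index in self.task_indices:
            yield next(iterators[task_index]), self.task_names[task_index]


dl = MultitaskDataLoader(["A", "B"], [["task-A"] * 5, ["task-B"] * 10])
for epoch in range(2):
    names = [name for _, name in dl]
    print(epoch, len(names), names.count("A"), names.count("B"))
# 0 15 5 10
# 1 15 5 10
```

Making __iter__ a generator, rather than resetting state and returning self, also means two loops over the same object (for example, nested loops) do not interfere with each other. If a loader can yield fewer batches than its len() reports, next() would raise inside the generator and Python would turn that into a RuntimeError, so the lengths need to be accurate.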