Distributed training with iterative dataloaders

I have a large dataset with an iterative dataloader, shown below (it samples over multiple tasks):

class TaskDataLoader:
    """Wrapper around a DataLoader that tags every batch with its task name.

    Args:
        task: Name of the task; written under the ``"task"`` key of each batch.
        dataset: Dataset handed to the underlying ``DataLoader``.
        batch_size: Number of examples per batch.
        collate_fn: Optional collate function forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        sampler: Optional sampler forwarded to ``DataLoader`` (for example a
            ``DistributedSampler`` so each process reads its own shard).
    """

    def __init__(self, task, dataset, batch_size=8,
                 collate_fn=None, drop_last=False,
                 num_workers=0, sampler=None):
        self.dataset = dataset
        self.task = task
        self.batch_size = batch_size
        # Every loader option accepted by this constructor is forwarded
        # unchanged to the underlying DataLoader.
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=batch_size,
                                      sampler=sampler,
                                      collate_fn=collate_fn,
                                      drop_last=drop_last,
                                      num_workers=num_workers)

    def __len__(self):
        """Return the number of batches in the underlying DataLoader."""
        return len(self.data_loader)

    def __iter__(self):
        """Yield each batch with the task name added under ``"task"``."""
        for batch in self.data_loader:
            # Requires batches that support item assignment (dict-like).
            batch["task"] = self.task
            yield batch

# Note: do not use itertools.cycle here. It saves a copy of every element
# it yields and replays the saved copies once the input is exhausted,
# resulting in issues in the dataloading pipeline.
# See https://docs.python.org/3.7/library/itertools.html?highlight=cycle#itertools.cycle
def cycle(iterable):
    """Yield the items of ``iterable`` over and over by re-iterating it.

    Every pass starts a fresh iteration over ``iterable`` instead of
    replaying saved items, so a re-iterable source produces a new pass each
    time. The generator stops when a pass yields nothing (an empty iterable,
    or a one-shot iterator that is already exhausted) rather than looping
    forever without producing a value.
    """
    while True:
        produced = False
        for x in iterable:
            produced = True
            yield x
        if not produced:
            # Nothing came out of this pass; another pass would never yield.
            return

class MultiTaskDataLoader:
    """Given a dictionary of task: dataset, returns a multi-task dataloader
    which uses temperature sampling to sample different datasets.

    Args:
        max_steps: Number of batches produced by one pass of ``__iter__``.
        tasks_to_datasets: Mapping from task name to dataset.
        batch_size: Batch size passed to every per-task ``TaskDataLoader``.
        collate_fn: Collate function passed to every ``TaskDataLoader``.
        drop_last: Passed to every ``TaskDataLoader``.
        num_workers: Passed to every ``TaskDataLoader``.
        temperature: Sampling temperature, see ``temperature_sampling``.
    """

    def __init__(self, max_steps, tasks_to_datasets, batch_size=8, collate_fn=None,
                 drop_last=False, num_workers=0, temperature=100.0):
        # Computes a mapping from task to dataloaders.
        self.task_to_dataloaders = {
            task: TaskDataLoader(task, dataset, batch_size,
                                 collate_fn=collate_fn, drop_last=drop_last,
                                 num_workers=num_workers)
            for task, dataset in tasks_to_datasets.items()
        }
        self.tasknames = list(self.task_to_dataloaders)

        # Computes the temperature sampling weights (same order as tasknames).
        self.sampling_weights = self.temperature_sampling(
            self.dataloader_sizes.values(), temperature)
        # One endless iterator per task, so sampling a task never runs dry.
        self.dataiters = {k: cycle(v) for k, v in self.task_to_dataloaders.items()}
        self.max_steps = max_steps

    def temperature_sampling(self, dataset_sizes, temp):
        """Return normalized sampling weights for the given sizes.

        Each weight is proportional to ``(size / total_size) ** (1 / temp)``;
        the returned array sums to 1.
        """
        # Materialize once: the sizes are traversed twice below.
        sizes = list(dataset_sizes)
        total_size = sum(sizes)
        weights = np.array([(size / total_size) ** (1.0 / temp) for size in sizes])
        return weights / np.sum(weights)

    @property
    def dataloader_sizes(self):
        """Mapping from task name to ``len()`` of its dataloader (cached)."""
        # Exposed as a property because the rest of the class reads it as an
        # attribute (``self.dataloader_sizes.values()`` / ``.items()``).
        if not hasattr(self, '_dataloader_sizes'):
            self._dataloader_sizes = {k: len(v) for k, v in self.task_to_dataloaders.items()}
        return self._dataloader_sizes

    def __len__(self):
        """Return the summed length of all per-task dataloaders.

        This is not the number of batches ``__iter__`` yields; that is
        ``max_steps``.
        """
        return sum(self.dataloader_sizes.values())

    def num_examples(self):
        """Return the total number of examples across all task datasets."""
        return sum(len(dataloader.dataset) for dataloader in self.task_to_dataloaders.values())

    def __iter__(self):
        """Yield ``max_steps`` batches, sampling the task for each one."""
        for _ in range(self.max_steps):
            taskname = np.random.choice(self.tasknames, p=self.sampling_weights)
            dataiter = self.dataiters[taskname]
            outputs = next(dataiter)
            yield outputs

I need to use it for distributed training over GPUs/TPUs. For this, I shard the data across the cores:

    def get_sharded_data(self, num_replicas, rank):
        """Shard every training dataset for ``rank`` and store the result.

        Calls ``dataset.shard(num_replicas, rank)`` on each entry of
        ``self.train_dataset``, replaces ``self.train_dataset`` with the
        resulting name-to-shard mapping, and returns it.
        """
        shards_by_name = {
            name: full_dataset.shard(num_replicas, rank)
            for name, full_dataset in self.train_dataset.items()
        }
        # The unsharded mapping is overwritten, not kept alongside.
        self.train_dataset = shards_by_name
        return self.train_dataset

    def get_train_dataset_shards(self):
        """In case of multiprocessing, returns the sharded data for the given rank.

        On TPU with more than one core, shards by the XLA world size and
        ordinal. Otherwise, when ``self.args.local_rank != -1``, shards by the
        ``torch.distributed`` world size and rank. In every other case
        returns ``self.train_dataset`` unsharded.
        """
        if is_torch_tpu_available():
            if xm.xrt_world_size() > 1:
                return self.get_sharded_data(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
            return self.train_dataset
        if self.args.local_rank != -1:
            # The XLA helpers only apply on the TPU path; here the process
            # group supplies the world size and this process's rank.
            # NOTE(review): assumes torch.distributed is already initialized
            # when local_rank != -1 -- confirm against the launcher.
            import torch.distributed as dist
            return self.get_sharded_data(num_replicas=dist.get_world_size(), rank=dist.get_rank())
        return self.train_dataset

    def get_train_dataloader(self) -> DataLoader:
        Returns the training :class:`~torch.utils.data.DataLoader`.
        Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler (adapted
        to distributed training if necessary) otherwise.
        Subclass and override this method if you want to inject some custom behavior.
        train_dataset = self.get_train_dataset_shards()
        return MultiTaskDataLoader(

However, as you may have noticed, the single-task dataloader does not have any sampler. This does NOT work with distributed training and does not make the program run faster. Could you point me to the possible issues? Thanks.

Can you use the DistributedSampler instead to shard your data and train in a distributed fashion?