Distributed data sampler for iterable datasets over TPUs

I have some large text datasets stored in TFDS format, and I want to train seq2seq models on them efficiently on TPUs or multiple GPUs. For datasets with random access (map-style datasets), one can use the following function:

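```python
import torch_xla.core.xla_model as xm
from torch.utils.data import Dataset, DistributedSampler, RandomSampler

def get_tpu_sampler(dataset: Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)  # single device: plain shuffling
    # several TPU cores: each core (rank) gets a disjoint shard of the indices
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
```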
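To see what DistributedSampler does for a map-style dataset, here is a small sketch that runs without TPUs. The values num_replicas=2 and rank are arbitrary stand-ins for xm.xrt_world_size() and xm.get_ordinal(), and the 10-item TensorDataset is a toy placeholder for real data. With shuffle=False, each rank gets a disjoint, equally sized, strided slice of the indices.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

# Toy map-style dataset with 10 items (placeholder for the real data).
dataset = TensorDataset(torch.arange(10))

# Simulate 2 replicas in one process; each rank iterates over its own shard of indices.
shards = {
    rank: list(DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False))
    for rank in range(2)
}
print(shards)
```

This index-based splitting is exactly what an IterableDataset cannot offer, because it has no indices to hand out.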

I am not sure how to write dataloaders that allow distributed training over TPUs when the datasets are iterable. Could you provide some examples, and the best way to handle large-scale datasets in PyTorch?
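Unlike a map-style dataset, an IterableDataset cannot be given a sampler (DataLoader raises a ValueError if you pass one), so the sharding has to happen inside the dataset's own `__iter__`. Below is a minimal sketch, assuming every replica can re-read the same underlying stream in the same order. `ShardedIterable` is a hypothetical wrapper, not a PyTorch class: it combines the replica rank with the DataLoader worker id from `torch.utils.data.get_worker_info()` so that each (replica, worker) pair keeps a disjoint, round-robin slice. On TPUs you would pass `num_replicas=xm.xrt_world_size()` and `rank=xm.get_ordinal()`, as in get_tpu_sampler.

```python
import itertools

import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedIterable(IterableDataset):
    """Hypothetical wrapper: splits a re-iterable `source` across replicas and workers."""

    def __init__(self, source, num_replicas, rank):
        self.source = source  # must be re-iterable, e.g. a list or another IterableDataset
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        info = get_worker_info()  # None when iterated in the main process (num_workers=0)
        num_workers = info.num_workers if info is not None else 1
        worker_id = info.id if info is not None else 0
        # Every (rank, worker) pair gets its own shard id out of num_shards in total.
        num_shards = self.num_replicas * num_workers
        shard_id = self.rank * num_workers + worker_id
        # Keep items shard_id, shard_id + num_shards, shard_id + 2 * num_shards, ...
        return itertools.islice(iter(self.source), shard_id, None, num_shards)


if __name__ == "__main__":
    source = list(range(10))
    for rank in range(2):  # simulate two replicas in one process
        loader = DataLoader(ShardedIterable(source, num_replicas=2, rank=rank), batch_size=2)
        print(rank, [batch.tolist() for batch in loader])
```

With this approach every replica still reads the whole stream and discards the other shards, which is simple but wastes I/O; if the data is split into many files, assigning whole files to shards is usually cheaper. Also, if the number of items is not divisible by the number of shards, replicas can end up with different numbers of batches, which you may need to handle in the training loop (for example by dropping the last incomplete steps).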


The page does not let me edit my post, so to clarify: the datasets are not in TFDS format anymore; they are PyTorch IterableDatasets, and I am looking for a way to define a distributed sampler for this case. Thanks. Here is the definition of my dataloader for more info:

```python
from itertools import islice

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def cycle(iterable):
    # Re-iterate forever; a DataLoader with shuffle=True reshuffles on every pass.
    while True:
        for x in iterable:
            yield x


class DEBUG_dataset(Dataset):
    def __init__(self, alpha):
        self.d = (torch.arange(20) + 1) * alpha

    def __len__(self):
        return self.d.shape[0]

    def __getitem__(self, index):
        return self.d[index]


# https://medium.com/speechmatics/how-to-build-a-streaming-dataloader-with-pytorch-a66dd891d9dd
# https://discuss.pytorch.org/t/train-simultaneously-on-two-datasets/649/35
class MultiTaskDataloader(object):
    def __init__(self, dataloaders, tau=1.0):
        self.dataloaders = dataloaders
        # Sample each task with probability proportional to (number of batches) ** tau.
        Z = sum(pow(v, tau) for v in self.dataloader_sizes.values())
        self.tasknames, self.sampling_weights = zip(
            *((k, pow(v, tau) / Z) for k, v in self.dataloader_sizes.items())
        )
        self.dataiters = {k: cycle(v) for k, v in dataloaders.items()}

    @property
    def dataloader_sizes(self):
        # Number of batches per task, computed once and cached.
        if not hasattr(self, "_dataloader_sizes"):
            self._dataloader_sizes = {k: len(v) for k, v in self.dataloaders.items()}
        return self._dataloader_sizes

    def __len__(self):
        return sum(self.dataloader_sizes.values())

    def __iter__(self):
        for _ in range(len(self)):
            taskname = np.random.choice(self.tasknames, p=self.sampling_weights)
            dataiter = self.dataiters[taskname]
            # Yield a fresh dict each step so earlier outputs are not overwritten.
            yield {"batch": next(dataiter), "task": taskname}


if __name__ == "__main__":
    train_dl1 = DataLoader(DEBUG_dataset(10), batch_size=4, num_workers=0, shuffle=True)
    train_dl2 = DataLoader(DEBUG_dataset(1), batch_size=4, num_workers=0, shuffle=True)
    dataloader = MultiTaskDataloader({"task1": train_dl1, "task2": train_dl2})
    for batch in islice(dataloader, 5):
        print(batch)
```
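One thing to watch when distributing MultiTaskDataloader: `__iter__` draws tasks from NumPy's global random state, each process in distributed training has its own copy of that state, and nothing in the class keeps those copies in sync, so replicas are not guaranteed to pick the same task at the same step. If you want all replicas on the same task at each step (an assumption about your setup, not a requirement), one option is a dedicated generator seeded identically on every rank. The helpers `sampling_weights` and `task_schedule` below are hypothetical; the first reproduces the weighting rule from the class above, p_k = n_k^tau / sum_j n_j^tau.

```python
import numpy as np


def sampling_weights(sizes, tau=1.0):
    # Same rule as MultiTaskDataloader: p_k = n_k**tau / sum_j n_j**tau.
    names = list(sizes)
    scaled = np.array([sizes[k] ** tau for k in names], dtype=float)
    return names, scaled / scaled.sum()


def task_schedule(sizes, num_steps, seed, tau=1.0):
    # A generator seeded identically on every rank yields the same task order everywhere.
    names, probs = sampling_weights(sizes, tau)
    rng = np.random.default_rng(seed)
    return [names[i] for i in rng.choice(len(names), size=num_steps, p=probs)]


if __name__ == "__main__":
    sizes = {"task1": 5, "task2": 15}
    print(sampling_weights(sizes, tau=1.0))
    print(task_schedule(sizes, num_steps=10, seed=0))
```

Each rank would then look up the task for step i in this shared schedule instead of calling `np.random.choice` on the global state.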

To give more context, my end goal is to train this model (https://github.com/huggingface/transformers/blob/master/examples/seq2seq/finetune_trainer.py) in PyTorch on iterable datasets, so I need to handle iterable datasets efficiently for distributed training over TPUs (currently, line 117 of https://github.com/huggingface/transformers/blob/master/examples/seq2seq/seq2seq_trainer.py does not support iterable datasets). Thanks.
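Assuming that trainer line builds a sampler the way get_tpu_sampler does (the trainer's actual code is not reproduced here), a minimal sketch of a variant that tolerates iterable datasets is to return no sampler for them, since DataLoader raises a ValueError when a sampler is combined with an IterableDataset. `get_sampler` and `Stream` are hypothetical names; world_size and rank are passed in explicitly (on TPUs, `xm.xrt_world_size()` and `xm.get_ordinal()`) so the sketch runs without torch_xla.

```python
from typing import Optional

import torch
from torch.utils.data import (
    Dataset,
    DistributedSampler,
    IterableDataset,
    RandomSampler,
    Sampler,
    TensorDataset,
)


def get_sampler(dataset: Dataset, world_size: int, rank: int) -> Optional[Sampler]:
    # DataLoader rejects a sampler for an IterableDataset; sharding must happen in its __iter__.
    if isinstance(dataset, IterableDataset):
        return None
    if world_size <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=world_size, rank=rank)


class Stream(IterableDataset):
    # Toy IterableDataset used only to illustrate the dispatch.
    def __iter__(self):
        return iter(range(4))


if __name__ == "__main__":
    print(get_sampler(Stream(), world_size=8, rank=0))
    print(type(get_sampler(TensorDataset(torch.arange(4)), world_size=8, rank=0)).__name__)
```

With `sampler=None`, the DataLoader relies on the iterable dataset itself to give each replica its own share of the data.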

Hello there.

I’m the author of pytorch-resample.

It’s difficult for me to answer precisely because I have no experience with distributed training. It seems to me that you want to sample from your dataset. The samplers provided in the library I just mentioned take an IterableDataset as input. They also inherit from IterableDataset, which means that you can use them in place of an existing IterableDataset.
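The wrapping pattern described above (a sampler that consumes an IterableDataset and is itself an IterableDataset) can be sketched as follows. This is not pytorch-resample's implementation or API: `BernoulliSubsample`, `Numbers`, `rate` and `seed` are hypothetical names, and the rule used here (keep each item independently with probability `rate`) is simply the easiest one to illustrate the idea.

```python
import random

import torch
from torch.utils.data import DataLoader, IterableDataset


class BernoulliSubsample(IterableDataset):
    # Hypothetical sampler: wraps an IterableDataset and keeps each item with probability `rate`.
    def __init__(self, dataset, rate, seed=0):
        self.dataset = dataset
        self.rate = rate
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)  # fresh, seeded RNG per pass -> reproducible selection
        for x in self.dataset:
            if rng.random() < self.rate:
                yield x


class Numbers(IterableDataset):
    # Toy source stream: 0, 1, ..., n - 1.
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))


if __name__ == "__main__":
    sampled = BernoulliSubsample(Numbers(100), rate=0.3, seed=42)
    # Because the sampler is itself an IterableDataset, it goes straight into a DataLoader.
    for batch in DataLoader(sampled, batch_size=8):
        print(batch.tolist())
```

Since the wrapper is just another IterableDataset, it can also be wrapped again, for example by a per-replica sharding wrapper for distributed training.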

What exactly are you expecting? What do you need the sampler's __iter__ method to do?