Distributed data sampler for training iterable datasets on TPUs

I have an iterable dataset and need to define a distributed data sampler for it to train efficiently on TPUs. Below is an example distributed sampler for TPUs that works with non-iterable (map-style) datasets. Could you please provide an example with an iterable dataset (like a tf.data.Dataset) that gets converted to an iterable dataset in PyTorch, together with a distributed sampler that can be used on TPUs? Thank you.

import torch
import torch_xla.core.xla_model as xm
from torch.utils.data import DistributedSampler, RandomSampler

def get_tpu_sampler(dataset: torch.utils.data.Dataset):
    # Single device: no sharding needed, just shuffle
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    # Shard the dataset across TPU cores, one replica per core
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
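One complication worth noting: `DistributedSampler` only works with map-style datasets, because it draws indices. With an `IterableDataset` there is no sampler at all; instead, each replica typically shards the stream itself by skipping elements. Below is a minimal sketch of that idea — the `ShardedIterableDataset` class and its `generator_fn` argument are illustrative names, not part of any library. On TPU you would pass `xm.xrt_world_size()` and `xm.get_ordinal()` as `world_size` and `rank`; a `tf.data.Dataset` could feed the generator, e.g. via its `as_numpy_iterator()` method.

```python
import torch
from torch.utils.data import IterableDataset


class ShardedIterableDataset(IterableDataset):
    """Illustrative sketch: shard a stream by keeping every world_size-th element."""

    def __init__(self, generator_fn, world_size=1, rank=0):
        self.generator_fn = generator_fn  # zero-arg callable returning a fresh iterator
        self.world_size = world_size      # e.g. xm.xrt_world_size() on TPU
        self.rank = rank                  # e.g. xm.get_ordinal() on TPU

    def __iter__(self):
        # Element-skip sharding: replica `rank` keeps items whose index
        # is congruent to `rank` modulo `world_size`.
        for i, sample in enumerate(self.generator_fn()):
            if i % self.world_size == self.rank:
                yield sample


# Usage sketch with a plain Python stream standing in for tf.data:
shard0 = ShardedIterableDataset(lambda: iter(range(10)), world_size=2, rank=0)
shard1 = ShardedIterableDataset(lambda: iter(range(10)), world_size=2, rank=1)
```

Each shard can then be wrapped in a regular `DataLoader` with `sampler=None`. Note this scheme assumes every replica can iterate the full stream cheaply; for large pipelines you would usually push the sharding into the source (e.g. `tf.data.Dataset.shard`) instead.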

cc @ailzhang for XLA
cc @VitalyFedyunin for data loader

Hi @Rabeeh_Karimi, for all pytorch/xla (TPU) related questions, please open an issue in the pytorch/xla GitHub repo instead: https://github.com/pytorch/xla Thanks!