Hi
I have an iterable dataset and I need to define a distributed data sampler for it to train efficiently on TPUs. Below is the example distributed sampler for TPUs in the non-iterable (map-style) case. Could you please help me with an example where an iterable dataset (like a tf.data.Dataset) gets converted to a PyTorch iterable dataset, together with a distributed sampler (or equivalent sharding) that works on TPUs? I've put a rough sketch of what I'm imagining after the snippet. Thank you.
import torch
import torch_xla.core.xla_model as xm
from torch.utils.data import Dataset, DistributedSampler, RandomSampler

def get_tpu_sampler(dataset: Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)  # single process: plain shuffling is enough
    # multi-process: each TPU ordinal gets its own shard of the indices
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
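
For context, here is a rough sketch of the kind of thing I'm imagining. The wrapper name TfIterableDataset is just a placeholder, and I'm assuming the sharding can happen inside __iter__ via tf.data.Dataset.shard() rather than through a sampler, since PyTorch samplers don't apply to iterable datasets:

import tensorflow as tf
import torch
import torch_xla.core.xla_model as xm
from torch.utils.data import IterableDataset

class TfIterableDataset(IterableDataset):  # placeholder name
    """Wraps a tf.data.Dataset and shards it across TPU processes."""

    def __init__(self, tf_dataset):
        self.tf_dataset = tf_dataset

    def __iter__(self):
        ds = self.tf_dataset
        # Shard here instead of using a DistributedSampler: each TPU
        # ordinal only sees every world_size-th example.
        if xm.xrt_world_size() > 1:
            ds = ds.shard(xm.xrt_world_size(), xm.get_ordinal())
        # Assumes each element is a single array; tuples/dicts would
        # need per-field conversion.
        for example in ds.as_numpy_iterator():
            yield torch.from_numpy(example)

The DataLoader would then be built without any sampler (samplers aren't allowed for iterable datasets) and wrapped with torch_xla.distributed.parallel_loader.MpDeviceLoader as usual. Is that the right approach, or is there a proper sampler-based way to do this?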