DDP + Torchdata


I have been trying to use torchdata with DDP, but I am running into the following error:

ValueError: DataLoader with IterableDataset: expected unspecified sampler option, but got sampler=<torch.utils.data.distributed.DistributedSampler object at 0x7fdcbf3afa30>

What is the correct approach for doing this?


cc @VitalyFedyunin and @ejguan for dataloader questions

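Since you are using an IterableDataset, you should implement the sharding inside your dataset class rather than using DistributedSampler. As the error says, DataLoader does not accept a sampler for an IterableDataset. Shard the data based on both the DataLoader worker id and the process rank in the distributed environment, so that every (rank, worker) pair reads a different subset.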
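The rank/worker arithmetic can be checked on its own, without torch. Below is a minimal pure-Python sketch assuming the interleaving `global_worker_id = worker_id * world_size + rank` and a round-robin split by sample index. `shard_indices` is a hypothetical helper, not a PyTorch API. The setup (2 processes, 2 workers each, 10 samples) is made up for illustration.

```python
def shard_indices(n_items, rank, world_size, worker_id, num_workers):
    """Indices of the samples one DataLoader worker on one rank would yield."""
    total_workers = num_workers * world_size
    # Interleave ranks within each worker slot: unique id in [0, total_workers)
    global_worker_id = worker_id * world_size + rank
    return [i for i in range(n_items) if i % total_workers == global_worker_id]


if __name__ == "__main__":
    # Hypothetical setup: 2 DDP processes, each with 2 DataLoader workers
    world_size, num_workers, n_items = 2, 2, 10
    for rank in range(world_size):
        for w in range(num_workers):
            idx = shard_indices(n_items, rank, world_size, w, num_workers)
            print(f"rank={rank} worker={w} -> {idx}")
    # rank=0 worker=0 -> [0, 4, 8]
    # rank=0 worker=1 -> [2, 6]
    # rank=1 worker=0 -> [1, 5, 9]
    # rank=1 worker=1 -> [3, 7]
```

Every index lands in exactly one shard, which is the property DistributedSampler would otherwise provide.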

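import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

class CustomDS(IterableDataset):
    def __init__(self, your_data):
        super().__init__()
        self.your_data = your_data  # any iterable of samples

    def __iter__(self):
        # get_worker_info() returns None when num_workers=0 (main process)
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        total_workers = worker_info.num_workers if worker_info is not None else 1
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            # Not running under DDP: behave like a single process
            world_size = 1
            rank = 0
        # One slot per (rank, worker) pair across all processes
        total_workers *= world_size
        global_worker_id = worker_id * world_size + rank
        for i, d in enumerate(self.your_data):
            # Round-robin: each global worker keeps every total_workers-th sample
            if i % total_workers == global_worker_id:
                yield d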
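One consequence of this round-robin split is worth checking. When the number of samples is not a multiple of `num_workers * world_size`, ranks receive different numbers of samples. DDP expects every rank to run the same number of iterations; a rank with extra batches will wait in the gradient all-reduce. Below is a minimal pure-Python sketch that counts samples per rank under the same id scheme. `samples_per_rank` is a hypothetical helper, not a PyTorch API.

```python
from collections import Counter


def samples_per_rank(n_items, world_size, num_workers):
    """Count how many samples each rank receives under the round-robin split."""
    total_workers = num_workers * world_size
    counts = Counter()
    for i in range(n_items):
        gid = i % total_workers   # global worker that owns sample i
        rank = gid % world_size   # invert gid = worker_id * world_size + rank
        counts[rank] += 1
    return [counts[r] for r in range(world_size)]


print(samples_per_rank(11, world_size=2, num_workers=2))  # [6, 5]
print(samples_per_rank(12, world_size=2, num_workers=2))  # [6, 6]
```

If the counts differ, common options are truncating every rank to the same length or using DistributedDataParallel's `join()` context manager, which is meant for uneven inputs.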