DDP + Torchdata

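Since you are using an `IterableDataset`, you cannot use `DistributedSampler`: `DataLoader` does not accept a sampler for iterable-style datasets. Instead, do the sharding inside your dataset class. Every sample must go to exactly one (process rank, DataLoader worker) pair, so shard on both the worker id and the rank when running in a distributed environment.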

import itertools

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset


class CustomDS(IterableDataset):
    """Iterable-style dataset that shards samples across DDP ranks and DataLoader workers.

    Each (rank, worker) pair receives a disjoint, strided subset of ``data``.
    Taken together, the subsets cover every sample exactly once.

    Args:
        data: Iterable of samples. It is iterated from the start on every call
            to ``__iter__``, so it should be re-iterable (e.g. a list or range)
            if you train for more than one epoch.

    NOTE: if the number of samples is not divisible by
    ``num_workers * world_size``, different ranks can receive different numbers
    of samples. DDP expects every rank to run the same number of steps, so
    consider dropping the remainder or using
    ``torch.distributed.algorithms.Join``.
    """

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # get_worker_info() is None in the main process (num_workers=0, or the
        # dataset is iterated directly), so treat that as a single worker.
        if worker_info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size = 1
            rank = 0

        total_workers = num_workers * world_size
        # Unique id in [0, total_workers) for this (worker, rank) pair.
        global_worker_id = worker_id * world_size + rank
        # Every total_workers-th sample, starting at this worker's offset.
        yield from itertools.islice(self.data, global_worker_id, None, total_workers)
2 Likes