I have been trying to use torch data + DDP but I am running into the following error:
ValueError: DataLoader with IterableDataset: expected unspecified sampler option, but got sampler=<torch.utils.data.distributed.DistributedSampler object at 0x7fdcbf3afa30>
Since you are using an IterableDataset, you should do the sharding inside your dataset class rather than passing a DistributedSampler to the DataLoader. Each DataLoader worker on each rank needs to yield a disjoint slice of the data, so shard on both the worker id and the rank in the distributed environment.
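import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

class CustomDS(IterableDataset):
    def __init__(self, ...):
        ...

    def __iter__(self):
        # get_worker_info() returns None when loading in the main process (num_workers=0)
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_id, total_workers = 0, 1
        else:
            worker_id = worker_info.id
            total_workers = worker_info.num_workers
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size = 1
            rank = 0
        # Total number of shards = workers per rank * number of ranks
        total_workers *= world_size
        # Unique shard index in [0, total_workers) for this (worker, rank) pair
        global_worker_id = worker_id * world_size + rank
        for i, d in enumerate(your_data):
            if i % total_workers == global_worker_id:
                yield d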
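
To sanity-check the index arithmetic without launching DDP, here is a minimal pure-Python sketch of the same modulo sharding. The helper names shard and all_shards are hypothetical and used only for illustration; they are not part of torch. The sketch assumes 2 ranks with 2 workers each and verifies that every item is yielded exactly once across all (rank, worker) pairs.

```python
from itertools import product


def shard(items, worker_id, num_workers, rank, world_size):
    """Return the items one (rank, worker) pair would yield (hypothetical helper)."""
    total_workers = num_workers * world_size
    global_worker_id = worker_id * world_size + rank
    return [x for i, x in enumerate(items) if i % total_workers == global_worker_id]


def all_shards(items, num_workers, world_size):
    """Map every (rank, worker_id) pair to its shard (hypothetical helper)."""
    return {
        (rank, worker_id): shard(items, worker_id, num_workers, rank, world_size)
        for rank, worker_id in product(range(world_size), range(num_workers))
    }


data = list(range(10))
shards = all_shards(data, num_workers=2, world_size=2)

# Every item appears exactly once across all shards.
merged = sorted(x for s in shards.values() for x in s)
assert merged == data
```

Note that the shards are not necessarily the same length (here 10 items split over 4 shards gives sizes 3, 3, 2, 2), which can matter under DDP if ranks end up with different numbers of batches.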