Since you are using IterableDataset, you should implement the sharding inside your dataset class rather than using DistributedSampler. Each copy of the dataset has to shard the data based on both the DataLoader worker ID and the process rank in a distributed environment, so that every sample is read by exactly one worker:
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset


class CustomDS(IterableDataset):
    def __init__(self, ...):
        ...

    def __iter__(self):
        # DataLoader workers within this process; get_worker_info() returns
        # None when num_workers=0 (i.e. loading in the main process).
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1
        # Processes across the distributed job.
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size = 1
            rank = 0
        total_workers = num_workers * world_size
        # Unique ID for this worker across all ranks.
        global_worker_id = worker_id * world_size + rank
        # Round-robin sharding: every sample is yielded by exactly one worker.
        for i, d in enumerate(your_data):
            if i % total_workers == global_worker_id:
                yield d
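For context, here is a minimal sketch of how this dataset would be consumed; the backend choice, batch size, worker count, and torchrun launch are my assumptions, not part of your setup:

# Usage sketch, assuming the job is launched with torchrun
# (e.g. `torchrun --nproc_per_node=2 train.py`) and that CustomDS
# needs no required constructor arguments here.
import torch.distributed as dist
from torch.utils.data import DataLoader

dist.init_process_group(backend="gloo")  # or "nccl" on GPUs

ds = CustomDS()
# Two worker processes per rank; __iter__ shards the stream across
# world_size * num_workers workers with no overlap.
loader = DataLoader(ds, batch_size=32, num_workers=2)

for batch in loader:
    ...  # training step

dist.destroy_process_group()

Each rank builds its own DataLoader, and the modulo sharding in __iter__ guarantees that no sample is seen twice across ranks or workers.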