Motivation:
I have a large dataset split across multiple shards (separate files) on disk. It fits into memory once. I'd like to distribute training with DDP and use multiple workers in my dataloader, and I'd like random access, i.e. a map-style dataset.
This setup will create world_size * num_workers processes, each trying to fully load the dataset into memory, which would far exceed my memory limits. Hence I'd like each worker to handle only certain shards.
Problem:
How can I
- distribute indices across ranks according to the sharding, i.e. each rank is only assigned indices from certain shards, while all indices from one shard are assigned to exactly one rank?
- distribute indices across workers according to the sharding, i.e. `__getitem__` is only called for indices within the shards assigned to the current worker (and rank), while all indices from one shard are handled by exactly one worker (and rank)?
My thoughts and attempted solutions:
- I can build a custom distributed sampler that only draws from a subset of indices chosen based on the current rank. From my understanding, there is one sampler per rank.
- I can use a custom `worker_init_fn` to tell each worker's dataset which shards to load and which indices to handle. When `__getitem__` is called with an index outside this worker's range, it would be ignored / return `None`. Since workers are assigned disjoint subsets by default (correct?), no other worker will handle those ignored indices either. Thus, some items would be missing from my resulting batch.
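Roughly, what I have in mind for the rank-level part is a sampler like the sketch below. Everything here is an assumption of mine: `shard_sizes` (number of items per shard) and the layout where global indices run shard after shard are hypothetical, and shards are simply assigned round-robin over ranks.

```python
import torch
from torch.utils.data import Sampler


class ShardAwareDistributedSampler(Sampler):
    """Sketch: each rank samples only from the shards assigned to it.

    Assumes global indices are laid out shard after shard, and assigns
    whole shards to ranks round-robin, so every index of a shard lands
    on exactly one rank.
    """

    def __init__(self, shard_sizes, rank, world_size, shuffle=True, seed=0):
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        # Global start offset of every shard.
        starts = [0]
        for n in shard_sizes:
            starts.append(starts[-1] + n)
        # Round-robin shards over ranks, then collect this rank's indices.
        self.indices = []
        for shard_id in range(len(shard_sizes)):
            if shard_id % world_size == rank:
                self.indices.extend(range(starts[shard_id], starts[shard_id + 1]))

    def set_epoch(self, epoch):
        # Mirror DistributedSampler's API so shuffling differs per epoch.
        self.epoch = epoch

    def __iter__(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            perm = torch.randperm(len(self.indices), generator=g).tolist()
            return iter([self.indices[i] for i in perm])
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)
```

Note that, unlike the built-in `DistributedSampler`, this makes no attempt to equalize the number of samples per rank, so uneven shard sizes would need padding or `Join`-style handling.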
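And the `worker_init_fn` idea would look something like this sketch. `get_worker_info()` is the real PyTorch hook; the dataset attributes `shard_ids` and `load_shards` are hypothetical methods I'd add to my own dataset class:

```python
from torch.utils.data import get_worker_info


def shards_for_worker(shard_ids, worker_id, num_workers):
    """Round-robin assignment: every shard goes to exactly one worker."""
    return [s for i, s in enumerate(shard_ids) if i % num_workers == worker_id]


def worker_init_fn(worker_id):
    # Runs inside each worker process, after the dataset object has been
    # copied into it; get_worker_info() returns this worker's view.
    info = get_worker_info()
    dataset = info.dataset  # this worker's private copy of the dataset
    my_shards = shards_for_worker(dataset.shard_ids, worker_id, info.num_workers)
    # Hypothetical hook on my dataset: load only the assigned shards,
    # so only these shards occupy this worker's memory.
    dataset.load_shards(my_shards)
```

This limits memory per worker, but it doesn't solve the problem that the sampler may still route an index to a worker that didn't load the corresponding shard.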
Now I found this thread related to problem 2, but I don't really get the solution they are suggesting. It says:
> Each worker should have an independent sampler that generates all the indices, but you can decide what to do with them at each worker.
From my understanding there is exactly one sampler (per rank), living in the main process, and it outputs the indices for one batch at a time. Inside the DataLoader, those index batches are then dispatched to the workers (round-robin, I think?), but this behavior is not customisable, is it?
I'd like to have a hook here to tell the DataLoader which indices to assign to which worker. Or alternatively something that hands all the indices of a batch to all the workers, so that within each worker I can decide which of them to handle or ignore.
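One workaround I could imagine, though it leans on what I believe is an implementation detail (batches being dispatched to workers round-robin, which is not a documented guarantee): build a batch sampler that arranges batches so that batch i only contains indices from shards meant for worker i % num_workers. A rough sketch, where `shard_indices` (one index list per shard) is an assumed layout of mine:

```python
from torch.utils.data import Sampler


class ShardBatchSampler(Sampler):
    """Sketch: yield batches so that batch i holds only indices from
    shards assigned to worker i % num_workers.

    Relies on the DataLoader dispatching batches to workers round-robin,
    which is an implementation detail, not a public guarantee. Also
    assumes the per-worker batch counts are roughly balanced; once one
    worker's queue runs dry, the alignment at the tail is lost.
    """

    def __init__(self, shard_indices, batch_size, num_workers):
        self.num_workers = num_workers
        # Pre-build one batch queue per worker, shards assigned round-robin.
        per_worker = [[] for _ in range(num_workers)]
        for shard_id, idxs in enumerate(shard_indices):
            w = shard_id % num_workers
            for i in range(0, len(idxs), batch_size):
                per_worker[w].append(idxs[i:i + batch_size])
        self.per_worker = per_worker

    def __iter__(self):
        # Interleave the queues so batch i comes from worker i % num_workers.
        queues = [list(q) for q in self.per_worker]
        while any(queues):
            for w in range(self.num_workers):
                if queues[w]:
                    yield queues[w].pop(0)

    def __len__(self):
        return sum(len(q) for q in self.per_worker)
```

This would be passed as `batch_sampler=` to the DataLoader, so each yielded list is one batch.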
Or is this an entirely bad idea and I should prefer a completely different approach?