In DataLoader with DDP and multiple workers, can I control which indices are assigned to which ranks and workers?

Motivation:
I have a large dataset split across multiple shards (separate files) on disk. I can load it into memory. I’d like to distribute training with DDP and use multiple workers in my DataLoader. I’d also like random access, i.e. a map-style dataset.

Naively, this will create world_size * num_workers worker processes, each holding its own copy of the dataset and trying to fully load it into memory, which will by far exceed memory limits. Hence I’d like each worker to handle only certain shards.

Problem:
How can I

  1. distribute indices across ranks according to sharding, i.e. each rank only gets assigned indices from certain shards, while indices from one shard are all assigned to exactly one rank.
  2. distribute indices across workers according to sharding, i.e. __getitem__ is only called for indices within the shards assigned to the current worker (and rank), while indices from one shard are all handled by exactly one worker (and rank).

My thoughts and attempted solutions

  1. I can build a custom distributed sampler that only draws from a subset of indices that is chosen based on the current rank. From my understanding, there is one sampler per rank.
  2. I can use a custom worker_init_fn to tell each worker’s dataset which shards to load and which indices to handle. So when __getitem__ is called, an index outside this worker’s range would be ignored/return None. Since workers are assigned disjoint subsets of indices by default (correct?), no other worker will handle those ignored indices either, so some items would be missing from my resulting batch. (Rough sketch of both ideas below.)
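
Roughly what I have in mind for both ideas; shards_for_rank, shard_to_indices and dataset.load_shards are my own placeholder names, not PyTorch APIs:

```python
import torch
from torch.utils.data import Sampler, get_worker_info


def shards_for_rank(rank, world_size, num_shards):
    # Placeholder policy: shard s belongs to rank s % world_size.
    return [s for s in range(num_shards) if s % world_size == rank]


class RankShardSampler(Sampler):
    """Idea 1: only yield indices that live in the shards assigned to this rank."""

    def __init__(self, shard_to_indices, rank, world_size):
        # shard_to_indices: list mapping shard id -> list of global indices
        self.indices = [
            i
            for s in shards_for_rank(rank, world_size, len(shard_to_indices))
            for i in shard_to_indices[s]
        ]

    def __iter__(self):
        perm = torch.randperm(len(self.indices)).tolist()
        return (self.indices[p] for p in perm)

    def __len__(self):
        return len(self.indices)


def worker_init_fn(worker_id):
    # Idea 2: tell this worker's copy of the dataset which shards it is responsible for.
    dataset = get_worker_info().dataset   # the per-worker copy of the dataset
    dataset.load_shards(worker_id)        # made-up method on my dataset class
```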

Now I found this thread related to problem 2, but I don’t really get the solution they are suggesting. It says

Each worker should have an independent sampler that generates all the indices, but you can decide what to do with them at each worker.

From my understanding there is exactly one sampler (per rank), living in the main process, and it yields the indices for one batch at a time. Then, inside the DataLoader, those batches of indices are dispatched to the workers (round-robin, I think?), but this behaviour is not customisable, is it?
I’d like to have a hook here to tell the DataLoader which indices to assign to which worker. Or alternatively, something that hands all the indices of a batch to all the workers (and within the worker I decide which of them I want to handle or ignore).
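
The only hook I’m aware of for “deciding at each worker” is torch.utils.data.get_worker_info(), which I could call inside __getitem__ and drop foreign indices there, roughly like below (load_item and the worker-to-shard rule are made up). But that just reproduces the missing-items problem from above:

```python
from torch.utils.data import Dataset, get_worker_info


class ShardedDataset(Dataset):
    def __init__(self, shard_to_indices):
        # e.g. {0: range(1, 11), 1: range(11, 21), ...}
        self.shard_to_indices = shard_to_indices

    def __len__(self):
        return sum(len(v) for v in self.shard_to_indices.values())

    def __getitem__(self, idx):
        info = get_worker_info()
        if info is not None:
            # Made-up rule: worker w only handles shard w.
            if idx not in self.shard_to_indices[info.id]:
                return None            # item is then simply missing from the batch
        return self.load_item(idx)     # made-up: fetch the sample from an in-memory shard
```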

Or is this an entirely bad idea and I should prefer a completely different approach?

I think approach 1 is the correct solution for this. With torch.utils.data.distributed.DistributedSampler each process receives an individual slice of the dataset (i.e. each sampler only samples from a disjoint subset of the dataset), so there won’t be any duplication.
You can also implement your map-style dataset such that it loads the individual shards one after the other, or loads chunks of each shard at random. Then you won’t have to worry about memory usage, no matter the total size; just implement __getitem__ accordingly.
All the samplers will sample from the same dataset.
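
Something along these lines, as a sketch (the shard file layout and torch.load are just placeholders for however your shards are actually stored):

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


class LazyShardDataset(Dataset):
    """Map-style dataset that keeps only the currently needed shard in memory."""

    def __init__(self, shard_files, samples_per_shard):
        self.shard_files = shard_files
        self.samples_per_shard = samples_per_shard
        self._cached_id = None
        self._cached_shard = None

    def __len__(self):
        return len(self.shard_files) * self.samples_per_shard

    def __getitem__(self, idx):
        shard_id, offset = divmod(idx, self.samples_per_shard)
        if shard_id != self._cached_id:
            # Placeholder: however you deserialize one shard file.
            self._cached_shard = torch.load(self.shard_files[shard_id])
            self._cached_id = shard_id
        return self._cached_shard[offset]


shard_files = [f"shard_{i}.pt" for i in range(10)]    # your shard paths
dataset = LazyShardDataset(shard_files, samples_per_shard=10)
sampler = DistributedSampler(dataset, shuffle=True)   # one disjoint slice per rank (requires init_process_group)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, num_workers=5)
```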

I’m not sure this accomplishes what I need.

I see how I can use the distributed sampler to distribute indices to ranks while respecting my existing shard structure.

But 1 and 2 are not alternatives, they should be complementary. So I still wonder how I can have each worker assigned only indices from certain shards (a subset of the indices within its rank), so that each worker only ever has to load those shards and nothing more.

Example:

  • My dataset has indices [1,…,100], split across 10 shards with 10 samples each.
  • I use DDP with world size 2.
  • The DataLoader shall have num_workers=5.

Then I want to have my indices distributed like this (see the sketch after this example):

|- rank 0: [1, …, 50]
    |- worker 1: loads shard 1, is assigned a random subset of [1, …, 10] at each iteration
     …
    |- worker 5: loads shard 5, is assigned a random subset of [41, …, 50] at each iteration

|- rank 1: [51,…,100]
    |- worker 1: loads shard 6, is assigned a random subset of [51, …, 60] at each iteration
     …
    |- worker 5: loads shard 10, is assigned a random subset of [91, …, 100] at each iteration
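
The only way I can see to get exactly this layout with a map-style dataset is a custom batch_sampler that emits one batch per worker per “round”, so that the DataLoader’s round-robin dispatch of batches to workers lines up with the shards. That dispatch order is an implementation detail rather than a documented guarantee, so this is just a sketch of the idea (all names are mine):

```python
import random
from torch.utils.data import Sampler


class PerWorkerShardBatchSampler(Sampler):
    """Batch k only contains indices from the shard meant for worker k % num_workers.

    This only produces the layout above if the DataLoader keeps handing batches
    to workers in round-robin order, which the docs do not guarantee.
    """

    def __init__(self, worker_to_indices, num_workers, batch_size):
        # worker_to_indices: one index list per worker of this rank,
        # e.g. for rank 0 above: [[1..10], [11..20], ..., [41..50]]
        assert len(worker_to_indices) == num_workers
        self.worker_to_indices = worker_to_indices
        self.num_workers = num_workers
        self.batch_size = batch_size

    def __iter__(self):
        # Build an independent, shuffled stream of batches per worker/shard.
        streams = []
        for indices in self.worker_to_indices:
            perm = random.sample(list(indices), len(indices))
            streams.append([perm[i:i + self.batch_size]
                            for i in range(0, len(perm), self.batch_size)])
        # Interleave: each round yields one batch for worker 0, 1, ..., num_workers - 1.
        for round_batches in zip(*streams):
            yield from round_batches

    def __len__(self):
        per_worker = min(
            (len(s) + self.batch_size - 1) // self.batch_size
            for s in self.worker_to_indices
        )
        return per_worker * self.num_workers
```

I would pass this as DataLoader(dataset, batch_sampler=..., num_workers=5, worker_init_fn=...) together with the worker_init_fn idea from my first post, but I’d much prefer a supported hook if one exists.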