Using IterableDataset with DistributedDataParallel

I’m building an NLP application with a dataloader that builds batches out of sequential blocks of text in a file. I have been using an IterableDataset since my text file won’t fit into memory. However, when I use it with DistributedDataParallel, the dataloader is replicated across processes and each GPU ends up with the same batch of data. How can I give each GPU a different batch of data to take advantage of distributed training?

Note: My dataloader can load ~300 batches/second on a single and each GPU takes ~2 seconds to process a batch, so dataloader speed should not be a significant limiting factor even if batches were sent serially to different GPUs.


How did you verify that all the GPUs are using the same batch? That’s not how DDP works; it will take different chunks automatically.

This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension. The module is replicated on each machine and each device, and each such replica handles a portion of the input. During the backwards pass, gradients from each node are averaged.

I ran into a similar problem recently, and based on the source code I think the batches will indeed be the same across GPUs.
If you look at DistributedSampler, the class we use with DDP, that is where the chunking is done. However, if you look at the DataLoader source code, the sampler does not affect how data is fetched from an iterable dataset.
see line 34 in
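To make the point above concrete, here is a minimal sketch, assuming a recent PyTorch release (older versions may behave differently). It shows DistributedSampler splitting the indices of a map-style dataset by rank, and DataLoader refusing a sampler altogether when the dataset is an IterableDataset. `num_replicas` and `rank` are passed explicitly so no process group is needed; `CountingStream` is a hypothetical stand-in for a streaming text dataset.

```python
from torch.utils.data import DataLoader, DistributedSampler, IterableDataset


class CountingStream(IterableDataset):
    """Hypothetical stand-in for a streaming dataset: yields 0..n-1."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))


def sampler_indices(dataset_len, num_replicas, rank):
    # DistributedSampler only needs len(dataset); passing num_replicas and
    # rank explicitly avoids calling torch.distributed.init_process_group().
    sampler = DistributedSampler(
        range(dataset_len), num_replicas=num_replicas, rank=rank, shuffle=False
    )
    return list(sampler)


def sampler_rejected_for_iterable():
    stream = CountingStream(10)
    sampler = DistributedSampler(range(10), num_replicas=2, rank=0, shuffle=False)
    try:
        DataLoader(stream, sampler=sampler)
    except ValueError:
        # DataLoader does not accept a sampler for an IterableDataset.
        return True
    return False


if __name__ == "__main__":
    print(sampler_indices(10, 2, 0))  # [0, 2, 4, 6, 8]
    print(sampler_indices(10, 2, 1))  # [1, 3, 5, 7, 9]
    print(sampler_rejected_for_iterable())
```

So with an IterableDataset, any per-rank splitting has to happen inside the dataset's own `__iter__`.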


@bsridatta I verified that my data was being replicated across GPUs by actually printing out the batches. I’m not sure, but this problem may be a product of using pytorch-lightning, which makes a copy of the dataloader for each GPU.

In any case, I was able to fix the problem by creating an array of pointers to the start of each training example in my file using an approach similar to the one used here. This allowed me to quickly sample random training contexts from my text file without needing to read it into memory.
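A minimal sketch of the offset-index idea, assuming one training example per newline-terminated line. The linked approach isn't reproduced here, so this is only an illustration, and `LineOffsetDataset` is a hypothetical name. Only the byte offsets stay in memory (8 bytes per example in an `array('Q')`), and each example is read on demand with `seek()`. Because it defines `__len__` and `__getitem__`, DataLoader treats it as a map-style dataset, so DistributedSampler can split it across ranks.

```python
from array import array


class LineOffsetDataset:
    """Map-style dataset over a text file, one example per line (assumption)."""

    def __init__(self, path):
        self.path = path
        # Byte offset of the start of every line; 'Q' = unsigned 64-bit ints.
        self.offsets = array("Q")
        pos = 0
        with open(path, "rb") as f:
            for line in f:
                self.offsets.append(pos)
                pos += len(line)

    def __len__(self):
        return len(self.offsets)

    def __getitem__(self, idx):
        # Open per call so the dataset is safe to use from DataLoader workers.
        with open(self.path, "rb") as f:
            f.seek(self.offsets[idx])
            return f.readline().decode("utf-8").rstrip("\r\n")
```

Random access means shuffling and per-rank splitting come for free from the usual sampler machinery, at the cost of one pass over the file to build the index.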

Hello @kartch, thanks a lot for explaining the workaround. You may be right about pytorch-lightning; I’ve had a few crazy issues with it too, some of the drawbacks of abstraction, I guess.

@VitalyFedyunin @SimonW I’m wondering if we officially support using DistributedSampler with IterableDataset?


In your dataset class, you can accept a shard ID and use it to shard the dataset. Using the distributed training rank as the shard ID should then work.
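A minimal sketch of that suggestion, assuming one example per line; `ShardedLineStream` is a hypothetical name. Line `i` goes to shard `i % num_shards`, and when no shard ID is given it falls back to the rank and world size from `torch.distributed` (or a single shard if no process group is initialized). Every rank still scans the whole file but keeps only its own lines, which should be cheap given how fast the loader in the original post is.

```python
import itertools

import torch.distributed as dist
from torch.utils.data import IterableDataset


class ShardedLineStream(IterableDataset):
    """Streams lines of a text file, keeping every num_shards-th line."""

    def __init__(self, path, shard_id=None, num_shards=None):
        self.path = path
        distributed = dist.is_available() and dist.is_initialized()
        if shard_id is None:
            shard_id = dist.get_rank() if distributed else 0
        if num_shards is None:
            num_shards = dist.get_world_size() if distributed else 1
        self.shard_id = shard_id
        self.num_shards = num_shards

    def __iter__(self):
        with open(self.path, encoding="utf-8") as f:
            # Line i belongs to shard i % num_shards, so shards are disjoint.
            for line in itertools.islice(f, self.shard_id, None, self.num_shards):
                yield line.rstrip("\n")
```

One caveat: if the number of lines isn't divisible by the world size, ranks can end up with different numbers of batches, which can stall DDP's collective calls, so you may need to drop or pad the tail.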


I had a similar use case and ended up implementing an IterableDataset that handles both multiprocessing via DataLoader and distributed training.
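The shared code isn't reproduced here, but one common way to handle both levels is sketched below (`RankAndWorkerStream` is a hypothetical name, and `range(n)` stands in for the real data source). `torch.utils.data.get_worker_info()` returns `None` in the main process and otherwise reports the worker's `id` and `num_workers`, so each (rank, worker) pair can claim its own global shard.

```python
import itertools

import torch.distributed as dist
from torch.utils.data import IterableDataset, get_worker_info


class RankAndWorkerStream(IterableDataset):
    """Splits range(n) across distributed ranks AND DataLoader workers."""

    def __init__(self, n, rank=None, world_size=None):
        self.n = n
        distributed = dist.is_available() and dist.is_initialized()
        if rank is None:
            rank = dist.get_rank() if distributed else 0
        if world_size is None:
            world_size = dist.get_world_size() if distributed else 1
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        info = get_worker_info()  # None when iterating in the main process
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        # Each (rank, worker) pair owns exactly one of the global shards.
        shard = self.rank * num_workers + worker_id
        num_shards = self.world_size * num_workers
        return itertools.islice(range(self.n), shard, None, num_shards)
```

With 2 ranks and 2 workers per rank this gives 4 disjoint shards: rank 0 reads shards 0 and 1, rank 1 reads shards 2 and 3.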

I shared my code here.

That said, I eventually turned away from that approach because it is relatively inflexible, and I simply turned the IterableDataset into an indexable (map-style) one.

But how did you overcome the memory problem? If I have 40M examples in my training set, it is not feasible to load all of them in each process. Did you use HDF5 or something similar?