DDP: DataLoader and spawn child processes


Looking at the torch.nn.parallel.DistributedDataParallel() documentation, it states: "This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension."

Does that mean I can call torch.utils.data.DataLoader() to build a train loader in the main process (i.e. before calling mp.spawn()), then spawn the child processes (ranks), and when I do:

for batch_idx, (inputs, targets) in enumerate(trainset):
    inputs, targets = inputs.to(rank), targets.to(rank)

DDP will automatically split the trainset across the ranks?
Or should I call DataLoader() within each child process?


Hi Brasilino, this seems like an imprecision in our documentation, good catch. I think the comment was copied over from DataParallel, but unlike DataParallel, DistributedDataParallel does not further split the input. You are expected to create a DataLoader instance on each rank and feed different batches into DistributedDataParallel.
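A minimal sketch of the sharding this implies (not a full DDP program, and the names here are illustrative): each spawned rank builds its own DataLoader, usually with a torch.utils.data.distributed.DistributedSampler so the ranks draw disjoint subsets of the dataset. Passing num_replicas and rank explicitly lets us demonstrate the split without actually spawning processes or initializing a process group:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 8 samples; in a real job this would be your trainset.
dataset = TensorDataset(torch.arange(8).float().unsqueeze(1), torch.arange(8))

world_size = 2  # pretend we spawned 2 ranks

# Collect the indices each rank would see. Inside a real spawned process you
# would construct ONE sampler/loader and let num_replicas/rank default to the
# values from the initialized process group.
shards = {}
for rank in range(world_size):
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    shards[rank] = list(sampler)
    print(f"rank {rank} sees indices {shards[rank]}")
```

With shuffle=False the sampler hands out indices round-robin, so rank 0 gets [0, 2, 4, 6] and rank 1 gets [1, 3, 5, 7]; together the ranks cover the whole dataset exactly once. In a real training loop you would also call sampler.set_epoch(epoch) each epoch so shuffling differs across epochs.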

Thanks @aazzolini for clarifying it.