The torch.nn.parallel.DistributedDataParallel() documentation states: "This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension."
Does that mean I can call torch.utils.data.DataLoader() to build a train loader in the main process (i.e., before calling mp.spawn()), spawn the child processes (ranks), and then run the following in each rank:

```python
for batch_idx, (inputs, targets) in enumerate(train_loader):
    inputs, targets = inputs.to(rank), targets.to(rank)
```
and DDP will automatically split the dataset across the ranks?
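For concreteness, this is roughly the layout I have in mind for that first option; train_worker, the dummy TensorDataset, and the batch size are just placeholders of mine:

```python
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset

def train_worker(rank, world_size, train_loader):
    # Each spawned rank iterates the same loader object that was
    # built in the main process before mp.spawn().
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(rank), targets.to(rank)
        # ... forward / backward / optimizer step with the DDP-wrapped model ...

if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))  # dummy data
    train_loader = DataLoader(dataset, batch_size=8, shuffle=True)  # built once, before spawning
    world_size = torch.cuda.device_count()
    mp.spawn(train_worker, args=(world_size, train_loader), nprocs=world_size)
```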
Or should I call DataLoader() inside each child process instead?
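If it has to be the latter, I assume it would look something like this, with DistributedSampler doing the actual sharding per rank (my guess, untested; the worker name and dummy dataset are again placeholders):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def train_worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))  # dummy data
    # The sampler hands each rank a disjoint shard of the dataset.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle the shard assignment each epoch
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(rank), targets.to(rank)
            # ... forward / backward with the DDP-wrapped model ...

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train_worker, args=(world_size,), nprocs=world_size)
```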