Batch chunking in DataParallel and DistributedDataParallel

I have seen this passage in the documentation and tried to understand it:

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the batch
    dimension. The module is replicated on each machine and each device, and
    each such replica handles a portion of the input. During the backwards
    pass, gradients from each node are averaged.
    The batch size should be larger than the number of GPUs used locally. It
    should also be an integer multiple of the number of GPUs so that each chunk
    is the same size (so that each GPU processes the same number of samples).

Does that mean the entire batch of data is first loaded on rank 0 and then chunks are sent to the other ranks? Or will each rank read only the part of the batch that falls in its chunk?
I do not know how that would be possible (for each worker to read only part of the batch) unless the data loader is required to subclass torch.utils.data.Dataset, which has __getitem__. If, for example, the data comes from a custom class, how would DataParallel know how to read the proper chunk? And if the entire batch is read each time, it could be inefficient to send a huge amount of data, say images, among nodes, when each node could just read the images from a shared directory.
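
For context, the kind of map-style dataset I have in mind looks roughly like this (just a sketch; SharedDirImages and the directory layout are placeholders I made up). Since every sample is fetched by index through __getitem__, a sampler could in principle hand each worker only its own indices, and nothing would ever need to be sent between nodes:

    import os
    from PIL import Image
    from torch.utils.data import Dataset

    class SharedDirImages(Dataset):
        """Hypothetical map-style dataset reading from a shared directory."""

        def __init__(self, root, transform=None):
            # every process sees the same file list, e.g. on a shared filesystem
            self.paths = sorted(
                os.path.join(root, name)
                for name in os.listdir(root)
                if name.lower().endswith((".jpg", ".png"))
            )
            self.transform = transform

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, idx):
            # each sample is read lazily by index, so a worker only touches
            # the files whose indices it was actually given
            img = Image.open(self.paths[idx]).convert("RGB")
            if self.transform is not None:
                img = self.transform(img)
            return img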

Looking at the code, it does look like DistributedDataParallel scatters the input! This is very inefficient when the inputs are, for example, images. It would be much more efficient for each process to read its own portion from a shared directory/file.
I think I would subclass DistributedDataParallel and do just that.

How is the performance of distributed PyTorch with large inputs on multi-GPU setups (say 32+ nodes)?

@dashesy You can use torch.utils.data.distributed.DistributedSampler (docs) to prevent input data from being broadcast among nodes. You can see it in action in the distributed ImageNet example.
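
Roughly, the relevant part of that example looks like this (a simplified, self-contained sketch rather than the exact example code; the TensorDataset is just a stand-in for an image dataset):

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    # assumes the process group has already been initialized, e.g.
    # dist.init_process_group(backend="nccl", init_method="env://")

    # stand-in dataset; in practice this would be ImageFolder or a custom Dataset
    dataset = TensorDataset(torch.randn(1024, 3, 224, 224),
                            torch.randint(0, 10, (1024,)))

    # DistributedSampler partitions the indices across ranks, so each process
    # only ever loads its own subset of the data -- nothing is sent between nodes
    sampler = DistributedSampler(dataset)

    loader = DataLoader(dataset,
                        batch_size=64,      # per-process batch size
                        shuffle=False,      # the sampler handles shuffling
                        sampler=sampler,
                        num_workers=4,
                        pin_memory=True)

    for epoch in range(10):
        sampler.set_epoch(epoch)  # new shuffle each epoch, consistent across ranks
        for images, targets in loader:
            pass  # forward/backward as usual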


A few more observations:

I am trying to understand how DistributedSampler prevents scattering inputs among nodes when using DistributedDataParallel. According to this answer, the ImageNet example uses torch.utils.data.distributed.DistributedSampler to do that.

This line:

    # compute output
    output = model(input)

however, calls forward on DistributedDataParallel, which does call scatter:

    def forward(self, *inputs, **kwargs):
        self.need_reduction = True
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)

So even though DistributedSampler has the data already chunked up for each replica, the batches will still be scattered, but they will be scattered among local devices, not among nodes! It took me some time to realize that device_ids are all the local devices on each node and that DistributedDataParallel does not scatter among nodes.
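
If I run one process per GPU and pass a single device in device_ids, that local scatter has nothing to split either. A minimal sketch (assuming LOCAL_RANK and the usual MASTER_ADDR/RANK/WORLD_SIZE variables are set by whatever launcher is used):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel

    # one process per GPU; LOCAL_RANK here is an assumption about the launcher
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")

    model = torch.nn.Linear(128, 10).cuda(local_rank)

    # with a single entry in device_ids there is only one local replica,
    # so the scatter in forward() has nothing to split across devices
    model = DistributedDataParallel(model, device_ids=[local_rank])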