I have seen this and tried to understand the documentation:
This container parallelizes the application of the given module by
splitting the input across the specified devices by chunking in the batch
dimension. The module is replicated on each machine and each device, and
each such replica handles a portion of the input. During the backwards
pass, gradients from each node are averaged.
The batch size should be larger than the number of GPUs used locally. It
should also be an integer multiple of the number of GPUs so that each chunk
is the same size (so that each GPU processes the same number of samples).
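To make the chunking described above concrete, here is a pure-Python sketch of the arithmetic (`chunk_batch` is a made-up helper for illustration, not PyTorch code): the batch is split along dimension 0 into one equal-sized chunk per device, which is why the docs ask for the batch size to be an integer multiple of the GPU count.

```python
# Sketch (not PyTorch's actual implementation) of how a batch is chunked
# along the batch dimension before one chunk is handed to each local GPU.

def chunk_batch(batch, num_devices):
    """Split a batch (here, a list of samples) into one chunk per device."""
    chunk_size = len(batch) // num_devices  # equal chunks: why an integer multiple is wanted
    return [batch[i * chunk_size:(i + 1) * chunk_size] for i in range(num_devices)]

batch = list(range(8))          # batch size 8
chunks = chunk_batch(batch, 4)  # 4 GPUs -> 4 chunks of 2 samples each
print(chunks)                   # [[0, 1], [2, 3], [4, 5], [6, 7]]
```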
Does that mean the entire batch of data is first loaded on rank 0 and then chunks are sent to the other ranks? Or will each rank read only the part of the batch that falls in its chunk?
I do not know how that would be possible (for each worker to read only part of the batch) unless the data loader is required to subclass torch.utils.data.Dataset and implement __getitem__. If, for example, the data comes from a custom class, how would DataParallel know how to read the proper chunk? And if the entire batch is read each time, it could be inefficient to send a huge amount of, say, images among nodes, when each node could just read the images from a shared directory.
Looking at the code, it does look like distributed scatters the input! This is very inefficient when the inputs are images, for example. It would be much more efficient for each process to read its own portion from a shared directory/file.
I think I would subclass DistributedDataParallel and do just that.
How is the performance of distributed PyTorch with large inputs on multi-GPU setups (say, 32+ nodes)?
@dashesy You can use torch.utils.data.distributed.DistributedSampler (docs) to prevent input data from being broadcast among nodes. You can see it in action in the distributed ImageNet example.
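The key idea behind DistributedSampler is that each rank draws a disjoint subset of dataset indices, so every process loads only its own samples and no input data ever crosses nodes. Here is a pure-Python sketch of that partitioning (`shard_indices` is an illustrative helper, not the sampler's API; shuffling is omitted):

```python
# Sketch of DistributedSampler-style index partitioning: each rank gets a
# strided, disjoint subset of indices, padded so all ranks see the same count.

def shard_indices(dataset_len, num_replicas, rank):
    """Return the dataset indices this rank would load (no shuffling)."""
    indices = list(range(dataset_len))
    # pad with leading indices so the dataset divides evenly among replicas,
    # mirroring what the real sampler does
    total = -(-dataset_len // num_replicas) * num_replicas  # ceil division
    indices += indices[: total - dataset_len]
    return indices[rank::num_replicas]  # strided subset for this rank

print(shard_indices(10, 4, 0))  # [0, 4, 8]
print(shard_indices(10, 4, 2))  # [2, 6, 0]  (wrapped padding)
```

Each rank then feeds its own index shard to its DataLoader, so the reading is local from the start.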
A few more observations:
I am trying to understand how DistributedSampler prevents scattering inputs among nodes when using DistributedDataParallel. According to this answer, the ImageNet example uses torch.utils.data.distributed.DistributedSampler to do that.
# compute output
output = model(input)
This, however, calls forward on DistributedDataParallel, which does call scatter:
def forward(self, *inputs, **kwargs):
    self.need_reduction = True
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
So even though DistributedSampler has already chunked up the data for each replica, the batches will still be scattered, but they are scattered among local devices, not among nodes! It took me some time to realize that device_ids are all local devices on each node, and that DistributedDataParallel does not scatter among nodes.
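A small sketch of that realization (illustrative only; `local_scatter` is a made-up helper, not PyTorch's scatter): each process already holds only its DistributedSampler shard, and the scatter in forward just redistributes that shard among the process's own local GPUs.

```python
# Sketch: "scatter" in DistributedDataParallel.forward splits the process's
# OWN batch among its local device_ids only -- nothing crosses node boundaries.

def local_scatter(batch, device_ids):
    """Split this process's batch across its local devices only."""
    per_device = len(batch) // len(device_ids)
    return {dev: batch[i * per_device:(i + 1) * per_device]
            for i, dev in enumerate(device_ids)}

# Node 0's process already holds only its sampler shard:
node0_shard = [0, 4, 8, 12]
print(local_scatter(node0_shard, device_ids=[0, 1]))
# {0: [0, 4], 1: [8, 12]}
```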