Unbalanced number of batches in distributed training

Hey, folks,

I wonder what PyTorch's expected behavior is when the data partition is not balanced, so that each training process ends up with a drastically different number of batches. As a concrete example: two training machines, each doing 10 epochs on its local partition of the data. Because the data is not balanced between partitions, one partition has 100 batches and the other has 10. How can we make PyTorch work in this scenario? Can the remaining 90 batches proceed without doing an all-reduce?

One idea is probably to let each partition have a different batch_size, but for some reason it is hard to get that information as well.

We have built uneven input support into DDP for exactly this purpose. It is designed to let users train DDP models when there is a different number of inputs across ranks (as long as the model performs no communication other than what DDP does itself), and to throw an exception that can be caught and recovered from if the model does have custom user communication.

Here are the docs: torch.nn.parallel.distributed — PyTorch 1.9.0 documentation

and the tutorial: Distributed Training with Uneven Inputs Using the Join Context Manager — PyTorch Tutorials 1.9.0+cu102 documentation. Feel free to follow up with any more questions!
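To make this concrete, here is a minimal sketch of how the question's scenario could be handled with DDP's `join()` context manager. Assumptions not in the original post: a gloo backend on CPU, a toy linear model, and hardcoded batch counts (100 on rank 0, 10 on rank 1) to mirror the example above.

```python
# Sketch: training with uneven numbers of batches per rank using
# DistributedDataParallel.join() (available since PyTorch 1.7).
# Assumptions: gloo backend on CPU, 2 processes, toy linear model.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(4, 1))
    opt = torch.optim.SGD(model.parameters(), lr=0.01)

    # Uneven partitions, as in the example: rank 0 has 100 batches,
    # rank 1 has only 10.
    num_batches = 100 if rank == 0 else 10
    inputs = [torch.randn(8, 4) for _ in range(num_batches)]

    # join() lets the rank that exhausts its batches "shadow" the
    # all-reduces of the ranks still training, so no rank hangs
    # waiting for a collective that would never be matched.
    with model.join():
        for x in inputs:
            opt.zero_grad()
            model(x).sum().backward()
            opt.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2, join=True)
```

Without the `with model.join():` block, rank 1 would finish after 10 batches and rank 0 would deadlock on its next gradient all-reduce; with it, rank 0's remaining 90 batches proceed while rank 1 participates only in the matching collectives.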