I am wondering about the recommended approach to balancing dataset sizes across devices while training with DDP. I have split my dataset across four GPUs, but one of them receives a single extra batch, which causes that rank to hang, waiting indefinitely for gradient synchronization with the other devices. I have thought of a few fixes, but each seems to have a drawback:
1. Throw out the final batch to guarantee an equal number of iterations on every rank.
2. Use the no_sync() context manager on the DDP-wrapped model for the final batch (see the sketch after this list). This will cause one device to end up with different model weights.
3. Proceed to the next epoch on the other devices and allow the first batch of epoch 2 to synchronize with this final batch from epoch 1.
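For reference, here is a minimal sketch of what I mean by option 2. It is not my real training loop: I'm assuming a standard DDP setup where the process group is already initialized, and the model/optimizer/loss names and the `num_common_batches` bookkeeping (computed with a MIN all-reduce) are just placeholders to illustrate the idea.

```python
import contextlib

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def train_one_epoch(ddp_model: DDP, loader, optimizer, device):
    # Every rank agrees on the number of batches that all ranks share.
    local_batches = torch.tensor(len(loader), device=device)
    dist.all_reduce(local_batches, op=dist.ReduceOp.MIN)
    num_common_batches = int(local_batches.item())

    for step, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # The extra batch (present on only one rank) runs under no_sync(),
        # so its backward pass skips the gradient all-reduce and never
        # waits on the other ranks.
        ctx = ddp_model.no_sync() if step >= num_common_batches else contextlib.nullcontext()
        with ctx:
            loss = torch.nn.functional.cross_entropy(ddp_model(inputs), targets)
            loss.backward()

        optimizer.step()
        optimizer.zero_grad()
```

The drawback I mentioned shows up here: the rank with the extra batch still calls optimizer.step() on its unsynced gradient, so its weights drift from the other ranks.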
I appreciate any suggestions you can give!