Training gets stuck in an 8-GPU setting

Hello,

I am using PyTorch (with the PyTorch Lightning framework) to train a text-to-text Transformer model (google/mt5-base at main).

I trained it with DDP on 1, 4, 5, and 8 GPUs.
However, every 8-GPU and 5-GPU training attempt got stuck and failed at the same point in the same epoch (54).

Below is the last log output before the hang. It appears to be the end of an epoch, so I assume training gets stuck while loading data for the next epoch in the 8-GPU and 5-GPU environments.

The issue also occurs regardless of the DataLoader's num_workers setting, and with different batch sizes (32 and 16).

Epoch 54: 100%|█████████▉| 2921/2931 [43:38<00:08,  1.12it/s, loss=.., v_num=0]
Epoch 54: 100%|█████████▉| 2925/2931 [43:41<00:05,  1.12it/s, loss=.., v_num=0]
Validating:  99%|█████████▉| 280/282 [02:32<00:01,  1.59it/s]
Epoch 54: 100%|██████████| 2931/2931 [44:01<00:00,  1.11it/s, loss=.., v_num=0]

Any comment or suggestion would be appreciated.

Thank you.

Do you have the same number of batches on every GPU? If not, training could get stuck, since DDP performs a collective sync (an allreduce of gradients) across all GPUs in the backward pass. If your batches are uneven across GPUs, you can try using the join API to get around this.
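A minimal sketch of the join API, assuming a plain PyTorch loop (not Lightning), CPU training with the gloo backend, and a file-based rendezvous. The toy model, the random data, and the deliberately uneven split (rank r gets 5 + r batches) are hypothetical. Inside `DistributedDataParallel.join()`, ranks that have run out of batches shadow the collective calls of the ranks that are still training, so no rank is left waiting for a peer that has already finished.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def train_with_join(rank: int, world_size: int, init_file: str) -> int:
    """Toy DDP loop where ranks see different numbers of batches.

    Returns the number of optimizer steps this rank performed.
    """
    # File-based rendezvous avoids choosing a TCP port; gloo runs on CPU.
    dist.init_process_group(
        "gloo", init_method=f"file://{init_file}", rank=rank, world_size=world_size
    )
    torch.manual_seed(0)
    model = DDP(nn.Linear(4, 1))
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # Hypothetical uneven split: rank r gets 5 + r batches of random data.
    batches = [torch.randn(8, 4) for _ in range(5 + rank)]

    steps = 0
    # Ranks that exhaust their batches early keep matching the allreduce
    # calls of the remaining ranks until every rank has left the context.
    with model.join():
        for x in batches:
            opt.zero_grad()
            loss = model(x).pow(2).mean()
            loss.backward()  # gradient allreduce happens here
            opt.step()
            steps += 1

    dist.destroy_process_group()
    return steps


if __name__ == "__main__":
    world_size = 2
    # A fresh, not-yet-existing file path for the rendezvous store.
    init_path = os.path.join(tempfile.mkdtemp(), "ddp_init")
    mp.spawn(train_with_join, args=(world_size, init_path), nprocs=world_size)
```

Without the `with model.join():` line, rank 1's sixth backward pass would issue an allreduce that rank 0 never matches, which is the kind of mismatch that produces a hang.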


@pritamdamania87,

First of all, thank you for your reply.

  • If I understand your comment correctly, I have run training with both even and uneven numbers of batches per GPU.
  • The log shows 2931, which is an uneven (odd) number of batches assigned to a single GPU. However, another training session with an even number of batches (e.g., 1466) ran into the same issue.
  • Could you suggest how to correctly measure the number of batches per GPU, in case my measurement was wrong?
  • I will of course look into the join API to try to get around this issue.
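To sanity-check the per-GPU batch count without running training, one can reproduce the arithmetic of `torch.utils.data.DistributedSampler`, assuming each rank is fed by that sampler (PyTorch Lightning typically inserts it automatically under DDP). By default the sampler pads the index list so every rank receives ceil(N / world_size) samples; with `drop_last=True` it drops the tail instead, so every rank gets floor(N / world_size). The helper names below are hypothetical, and the dataset size in the `__main__` block is made up.

```python
import math


def samples_per_rank(dataset_len: int, world_size: int, drop_last: bool = False) -> int:
    """Per-rank sample count, following DistributedSampler's formula."""
    if drop_last and dataset_len % world_size != 0:
        # The tail is dropped so that every rank gets the same number of samples.
        return math.ceil((dataset_len - world_size) / world_size)
    # Otherwise indices are padded (by repeating samples) to a multiple of world_size.
    return math.ceil(dataset_len / world_size)


def batches_per_rank(
    dataset_len: int,
    world_size: int,
    batch_size: int,
    sampler_drop_last: bool = False,
    loader_drop_last: bool = False,
) -> int:
    """Number of batches one rank's DataLoader yields per epoch."""
    n = samples_per_rank(dataset_len, world_size, sampler_drop_last)
    if loader_drop_last:
        # DataLoader(drop_last=True) discards the final partial batch.
        return n // batch_size
    return math.ceil(n / batch_size)


if __name__ == "__main__":
    # Hypothetical dataset size; substitute len(train_dataset).
    for ws in (1, 4, 5, 8):
        print(f"world_size={ws}: {batches_per_rank(100_003, ws, batch_size=16)} batches per rank")
```

With either sampler setting every rank gets the same number of samples, so per-rank batch counts would typically only differ if the data pipeline bypasses this sampler.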

Thank you!

  • Could you suggest how to correctly measure the number of batches per GPU, in case my measurement was wrong?

You would probably have to add some logging code to your application that records the number of batches on each rank as part of the training loop.
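A minimal, framework-agnostic sketch of such logging, assuming a hand-written training loop (with PyTorch Lightning the wrapper would have to be hooked in elsewhere, e.g. around the DataLoader). `BatchCounter` is a hypothetical helper, not a PyTorch API; the rank is passed in explicitly, for example from `torch.distributed.get_rank()`.

```python
import logging
from typing import Any, Iterable, Iterator

logger = logging.getLogger("batch_counter")


class BatchCounter:
    """Wrap one rank's data iterable and record how many batches it yields."""

    def __init__(self, loader: Iterable[Any], rank: int, epoch: int, log_every: int = 0) -> None:
        self.loader = loader
        self.rank = rank
        self.epoch = epoch
        self.log_every = log_every  # 0 disables per-batch progress lines
        self.count = 0

    def __iter__(self) -> Iterator[Any]:
        self.count = 0
        for batch in self.loader:
            self.count += 1
            if self.log_every and self.count % self.log_every == 0:
                # Progress lines show how far each rank got if it hangs mid-epoch.
                logger.info("epoch %d rank %d: reached batch %d", self.epoch, self.rank, self.count)
            yield batch
        # Local log only (no collective call), so it is still emitted even if
        # another rank is stuck.
        logger.info("epoch %d rank %d: %d batches in total", self.epoch, self.rank, self.count)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Toy stand-in for a DataLoader; in real code pass the loader and the actual rank.
    for _ in BatchCounter(range(7), rank=0, epoch=54, log_every=3):
        pass
```

Logging each rank locally, rather than gathering the counts with a collective such as `torch.distributed.all_gather_object`, is deliberate: if the ranks really are out of step, the collective would itself hang and the counts would never be printed.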