DDP training and eval question

Hi all,
I’m pretty new to using PyTorch, so please let me know if anything is unclear.

I’ve been working on a project to train a BERT4Rec model using DDP.

Since our input data consists of many rows spread across many files, I decided to implement a rank- and world_size-aware IterableDataset that simply yields rows round-robin based on rank and world_size, instead of using DistributedSampler.
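Roughly, the sharding idea looks like this (a simplified sketch, not my exact implementation; the file reading is just a placeholder):

from torch.utils.data import IterableDataset

class RoundRobinRowDataset(IterableDataset):
    # Each rank keeps only every world_size-th row, offset by its rank.
    def __init__(self, files, rank, world_size):
        self.files = files
        self.rank = rank
        self.world_size = world_size

    def _all_rows(self):
        # Placeholder: stream rows from every file in a fixed, deterministic order.
        for path in self.files:
            with open(path) as f:
                for line in f:
                    yield line

    def __iter__(self):
        for i, row in enumerate(self._all_rows()):
            if i % self.world_size == self.rank:
                yield row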

My model is wrapped with DDP and works perfectly with training.
For eval, it works well with smaller batch counts. However, with larger batch counts (around 1000 batches * 1024 batch size), the ranks fall out of sync after the eval loop, before I can call all_gather to consolidate the results.

I put a barrier after the loop ends, and what I saw was that a couple of ranks passed the barrier right away while the rest kept waiting for them at the barrier. This ends up hanging for 30 minutes before timing out.

There is no synchronization operation inside the eval loop, just a forward pass and eval-metric calculation for each batch on each rank.

I’ve tried several things to rule out possible mistakes:

  • Eliminated data imbalance - I tried drop_last = True and a manual max-index break condition for the loop
  • Called torch.distributed.barrier() after the eval loop ends to make sure all ranks finish the loop before the all_gather calls (roughly as in the sketch after this list)
  • Made sure the ranks are still in sync after the training loop
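
For reference, the eval phase is structured roughly like this (a simplified sketch; the real metrics are ranking metrics for BERT4Rec, plain accuracy here is just a stand-in):

import torch
import torch.distributed as dist

def evaluate(model, eval_loader, device):
    model.eval()
    correct = torch.zeros(1, device=device)
    total = torch.zeros(1, device=device)

    with torch.no_grad():
        for inputs, targets in eval_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)  # forward pass only, no explicit collectives here
            correct += (logits.argmax(dim=-1) == targets).sum()
            total += targets.numel()

    # Every rank should reach this point before the results are gathered.
    dist.barrier()

    world_size = dist.get_world_size()
    all_correct = [torch.zeros_like(correct) for _ in range(world_size)]
    all_total = [torch.zeros_like(total) for _ in range(world_size)]
    dist.all_gather(all_correct, correct)
    dist.all_gather(all_total, total)

    return (sum(all_correct) / sum(all_total)).item()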

In the end, the only thing that worked was to unwrap the model and use the bare module for the eval loop with
model = model.module
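Concretely, what I ended up doing amounts to something like this (reusing the evaluate sketch above; train_one_epoch is just a placeholder for my training loop):

# Keep the DDP wrapper for training...
train_one_epoch(ddp_model, train_loader)

# ...but run eval on the underlying module, so the eval forward pass
# does not go through DDP at all.
metrics = evaluate(ddp_model.module, eval_loader, device)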

I searched everywhere to see whether this is a correct thing to do, but I’ve not found much about it.
I’d like to hear if anyone has encountered this issue or has an idea of whether this is the right way to handle it.

Thanks!!


Hi, if you can get the call stack of where the model is hanging, that would give more actionable feedback. You can inspect the process with py-spy.

One hypothesis I have is that during the eval forward pass, DDP is doing a broadcast for the buffers in your module (see the broadcast_buffers flag in the DistributedDataParallel — PyTorch 2.7 documentation). Are you running eval with torch.no_grad()? You can also try setting broadcast_buffers=False.
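
Something like this, assuming the usual process-group setup is already done (the tiny Linear layer is just a stand-in for your model):

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = dist.get_rank() % torch.cuda.device_count()
model = nn.Linear(16, 4).cuda(local_rank)  # stand-in for your BERT4Rec model

# broadcast_buffers=False stops DDP from broadcasting module buffers from
# rank 0 at the start of each forward call, so eval forward passes no
# longer involve that collective.
ddp_model = DDP(model, device_ids=[local_rank], broadcast_buffers=False)

ddp_model.eval()
with torch.no_grad():
    out = ddp_model(torch.randn(8, 16, device=f"cuda:{local_rank}"))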


Thanks for the reply.
I’m running this as a SageMaker training job, so I won’t be able to run py-spy. I think SageMaker has its own profiling dashboard, but I’m not sure anything would show up since the job just hangs.
The eval loop is using torch.no_grad().
I will try running it with broadcast_buffers=False and report back. Thanks for pointing it out; I didn’t know the forward pass still requires some DDP communication.