DDP ranks hanging before running torch.distributed.all_reduce()

I am training a (relatively) small transformer language model with about 1.15B parameters. I am running this training job on my university's HPC cluster, which uses SLURM. I have adapted my code based on the DDP tutorial series on YouTube here: https://www.youtube.com/watch?v=-K3bZYHYHEA&list=PL_lsbAsL_o2CSuhUhJIiW0IkdT5C2wGWj

My setup:
Python 3.10.4
PyTorch/2.1.2-foss-2022a-CUDA-12.1.1
CUDA/12.1.1
4x NVIDIA A100 GPUs with 80 GB each
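For reference, here is a minimal sketch of the kind of DDP setup I'm using (placeholder names, not my exact code; it assumes the launcher exports LOCAL_RANK):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(model):
    # Assumes the launcher (e.g. torchrun under SLURM) exports LOCAL_RANK.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)           # pin this process to one A100
    dist.init_process_group(backend="nccl")     # NCCL backend for GPU collectives
    model = DDP(model.to(local_rank), device_ids=[local_rank])
    return model, local_rank
```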

Currently, my script runs an eval() function at the start of training to go over the validation dataset and get a baseline perplexity. Then I run the first epoch, and then I call eval() again to see how much the model improved after the first epoch. I noticed that rank 3 finishes eval first, so it hits the dist.all_reduce() line first. I tried placing dist.barrier() right before it, but that does not seem to have changed anything.

I also checked the data sizes for each rank. Every rank processes the same number of batches, and each batch has a batch size of 16. The sequence length varies from batch to batch because I did not pad all batches to the same length (there is a specific reason for this), but all sequences within a batch share the same length. This made me think that rank 3 simply had fewer total tokens to process, but it doesn't - it has the second fewest. Another rank has about 2,000 fewer tokens and doesn't even finish eval before rank 3 does. All ranks process roughly the same number of tokens per batch, at about 15-16 tokens per batch.
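Roughly, my eval/all_reduce pattern looks like the sketch below (placeholder names for the loader, device, and model call; this is a simplified sketch, not my exact code):

```python
import torch
import torch.distributed as dist

@torch.no_grad()
def evaluate(model, eval_loader, device):
    # Sum loss and token counts locally on each rank, then reduce across ranks.
    model.eval()
    total_loss = torch.zeros(1, device=device)
    total_tokens = torch.zeros(1, device=device)
    for input_ids, labels in eval_loader:       # 16 sequences per batch, same length within a batch
        input_ids, labels = input_ids.to(device), labels.to(device)
        logits = model(input_ids)
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.view(-1), reduction="sum"
        )
        total_loss += loss
        total_tokens += labels.numel()

    dist.barrier()                                        # the barrier I tried adding
    dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)     # <-- rank 3 reaches this first
    dist.all_reduce(total_tokens, op=dist.ReduceOp.SUM)
    return torch.exp(total_loss / total_tokens)           # aggregated perplexity
```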

Any idea why the job keeps hanging?