Hi, I’m trying to train a modded-nanogpt on PrimeIntellect and I’m getting weird latency spikes once every few hundred training steps (unfortunately not always at the same step): a regular training step takes ~60ms, while a step with a spike takes ~400ms. There is no error on the PyTorch side to indicate what the issue is, but it also hurts the validation loss of the run. I did one run with NCCL logging enabled and noticed this bunch of logs correlated with the lagging training step:
1f89df06ce99:50400:50560 [6] NCCL INFO Received and initiated operation=Close res=0
1f89df06ce99:50400:50560 [6] NCCL INFO Received and initiated operation=Close res=0
1f89df06ce99:50400:50560 [6] NCCL INFO Received and initiated operation=Close res=0
1f89df06ce99:50400:50560 [6] NCCL INFO Received and initiated operation=Stop res=0
1f89df06ce99:50399:50553 [5] NCCL INFO [Proxy Service UDS] exit: stop 1 abortFlag 0
1f89df06ce99:50399:50399 [5] NCCL INFO NVLS Unbind MC handle 7fe5a3d7f7d0 size 2097152 dev 5
1f89df06ce99:50399:50399 [5] NCCL INFO NVLS Unmap mem UC handle 0x7fe5a3d7fd80(0x7fe5b7c00000) ucsize 2097152 MC handle 0x7fe5a3d7f7d0(0x7fe5b7e00000) mcsize 2097152
1f89df06ce99:50399:50399 [5] NCCL INFO comm 0x56333fbb8930 rank 5 nranks 8 cudaDev 5 busId bb000 - Destroy COMPLETE
1f89df06ce99:50396:50550 [0] NCCL INFO RAS current socket connection with 172.17.0.3<47475> closed by peer on receive; terminating it
1f89df06ce99:50396:50550 [0] NCCL INFO RAS handling local termination request
1f89df06ce99:50396:50550 [0] NCCL INFO RAS terminating …
1f89df06ce99:50396:50550 [0] NCCL INFO RAS current socket connection with 172.17.0.3<60893> closed by peer on receive; terminating it
1f89df06ce99:50399:50545 [0] NCCL INFO RAS handling local termination request
1f89df06ce99:50399:50545 [0] NCCL INFO RAS terminating a connection with 172.17.0.3<60893>
1f89df06ce99:50399:50545 [0] NCCL INFO RAS terminating a connection with 172.17.0.3<52531>
1f89df06ce99:50399:50545 [0] NCCL INFO RAS thread terminating
1f89df06ce99:50401:50544 [0] NCCL INFO RAS current socket connection with 172.17.0.3<39245> closed by peer on receive; terminating it
1f89df06ce99:50401:50544 [0] NCCL INFO RAS handling local termination request
1f89df06ce99:50401:50544 [0] NCCL INFO RAS terminating a connection with 172.17.0.3<39245>
1f89df06ce99:50401:50544 [0] NCCL INFO RAS terminating a connection with 172.17.0.3<60893>
1f89df06ce99:50401:50544 [0] NCCL INFO RAS thread terminating
1f89df06ce99:50395:50395 [1] NCCL INFO NVLS Unmap mem UC handle 0x…(0x7f81d3c00000) ucsize 2097152 MC handle 0x7f81bfd7f7d0(0x7f81d3e00000) mcsize 2097152
1f89df06ce99:50395:50395 [1] NCCL INFO comm 0x55f252944d40 rank 1 nranks 8 cudaDev 1 busId 3b000 - Destroy COMPLETE
1f89df06ce99:50395:50547 [0] NCCL INFO RAS handling local termination request
1f89df06ce99:50395:50547 [0] NCCL INFO RAS terminating a connection with 172.17.0.3<49479>
1f89df06ce99:50395:50547 [0] NCCL INFO RAS terminating a connection with 172.17.0.3<39245>
1f89df06ce99:50395:50547 [0] NCCL INFO RAS thread terminating
The logs are surrounded by regular AllGather and ReduceScatter messages.
The setup uses eth0 for bootstrap (not SHM) and P2P for transport, as far as I can tell from reading the logs (I’d be happy to share the setup logs as well, but this post is already getting long).
The bad steps happen fairly consistently on a given rented machine (usually ±5 steps of variance), but at different steps on different machines.
Are there any logs I should enable from the Python side that might help explain what’s going on?
I ran with TORCH_DISTRIBUTED_DEBUG=DETAIL but had some issues with stderr redirection, so I only have partial logs from that run; here is what the problematic step has to say:
The unusual part starts with the `check key_count:1 keys[0]:/exception_dump` address line (note the gaps in the timestamps). It takes roughly 340ms to recover, but I don’t know from what…
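For the next run I’m planning to turn up the logging and keep each rank’s output separate, so the lines can’t interleave like in the excerpt above. These are the standard PyTorch/NCCL logging knobs as I understand them; `train.py` and the file paths are just placeholders for my actual entry point:

```shell
# Verbose distributed + NCCL logging, with per-rank log files so the
# stderr streams of different ranks can't tear each other's lines apart.
export TORCH_DISTRIBUTED_DEBUG=DETAIL        # extra c10d collective checks
export TORCH_CPP_LOG_LEVEL=INFO              # surface c10d's own INFO logs
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,COLL,NET       # limit NCCL noise to the relevant subsystems
export NCCL_DEBUG_FILE=/tmp/nccl.%h.%p.log   # one NCCL log per host (%h) and pid (%p)

# torchrun can also split each worker's stdout/stderr into its own file:
torchrun --nproc_per_node=8 --log-dir=logs --redirects=3 train.py
```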
I’ve been looking into this more and I now have a problematic run where all of the distributed portions of the code behave fine: only rank 0 is delayed by ~400ms, and it then blocks ranks 1–7 at the next distributed barrier…
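For reference, this is roughly how I’m flagging the slow steps from per-rank timings (a minimal sketch; `find_spikes` and the 3×-EMA threshold are my own arbitrary choices, not from any library):

```python
def find_spikes(step_times_ms, factor=3.0):
    """Return indices of steps whose duration exceeds `factor` times an
    exponential moving average of the normal steps seen so far."""
    spikes = []
    ema = None
    for i, dt in enumerate(step_times_ms):
        if ema is not None and dt > factor * ema:
            spikes.append(i)   # outlier: keep it out of the running average
        else:
            ema = dt if ema is None else 0.9 * ema + 0.1 * dt
    return spikes

# e.g. ten ~60 ms steps, one 400 ms spike, then back to normal
print(find_spikes([60.0] * 10 + [400.0] + [60.0] * 5))  # → [10]
```

Running this on the timings from each rank separately is what showed me that only rank 0 spikes.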
Hm, that’s quite interesting. So is it a correct understanding that rank 0 is a straggler causing the other ranks to be delayed? Have you looked into why rank 0 is behaving so slowly? I recommend taking a profiler trace and seeing where the delays come from: PyTorch Profiler — PyTorch Tutorials 2.9.0+cu128 documentation
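Since your spike lands in a fairly narrow window per machine, you can schedule the profiler to stay idle until just before it and only record those steps. A rough sketch (`BAD_STEP`, `WINDOW`, and `train_step` are placeholders for your run, not from your code):

```python
import os
import torch
from torch.profiler import ProfilerActivity, profile, schedule

rank = int(os.environ.get("RANK", "0"))
BAD_STEP = 300   # placeholder: step where the spike tends to land on this machine
WINDOW = 10      # covers the ~±5 steps of variance you mentioned

def train_step():
    # stand-in for the real forward/backward/optimizer step
    torch.randn(64, 64) @ torch.randn(64, 64)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

prof = profile(
    activities=activities,
    # idle until just before the suspect step, then record WINDOW steps
    schedule=schedule(wait=BAD_STEP - WINDOW // 2, warmup=1, active=WINDOW),
    on_trace_ready=lambda p: p.export_chrome_trace(f"trace_rank{rank}.json"),
)
prof.start()
for step in range(BAD_STEP + WINDOW):
    train_step()
    prof.step()   # advances the profiler's wait/warmup/active schedule
prof.stop()
```

Open the resulting `trace_rank0.json` in chrome://tracing or Perfetto and compare it against a trace from a healthy rank; the 340ms gap should show up as either a long CPU-side op, a CUDA sync, or idle time waiting on something external.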