Torch.compile + DDP Multi-Node: grad_norm becomes NaN starting from Epoch 2

Environment

  • PyTorch version: 2.4.0

  • CUDA version: 12.1

  • Hardware: 48 nodes × 8 NVIDIA L20 GPUs

  • NCCL version: 2.20.5

  • OS: Ubuntu 20.04

Problem Description

When using torch.compile with DDP in a multi-node setup, grad_norm becomes NaN or Inf starting from the second epoch.

Key observations:

  • :white_check_mark: Single-node + compile: Works fine

  • :white_check_mark: Multi-node + NO compile: Works fine

  • :white_check_mark: Multi-node + compile + Epoch 1: Works fine (identical to eager mode)

  • :x: Multi-node + compile + Epoch 2+: grad_norm = NaN

This suggests the issue is related to lazy compilation timing or cross-epoch state changes rather than the computation logic itself.

Loss Comparison (Compile vs No-Compile)

| Loss Component | No Compile | Compile | Diff  | Notes                                   |
|----------------|------------|---------|-------|-----------------------------------------|
| grad_norm      | 68.84      | NaN     | -     | :collision: Core issue                  |
| smooth_loss    | 133.97     | 99.73   | 34.24 | Large diff, involves 2nd derivatives    |
| vel            | 75.10      | 55.69   | 19.41 | Velocity prediction                     |
| acc            | 93.80      | 75.42   | 18.38 | Acceleration prediction                 |
| lat_bound      | 7.31       | 23.60   | 16.29 | :warning: 3× increase, boundary constraint |
| traj           | 77.96      | 62.01   | 15.95 | Trajectory loss                         |
| total_loss     | 17.19      | 15.33   | 1.86  | Loss looks normal, but grad is NaN      |
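Since the total loss stays finite while grad_norm goes NaN, it can help to find which parameter's gradient goes non-finite first, before everything is folded into a single norm. Below is a sketch of such a debugging aid (`install_nan_grad_hooks` is a hypothetical helper, not project code); it should also fire under DDP, since parameter hooks run during the backward pass.

```python
import torch

def install_nan_grad_hooks(model: torch.nn.Module) -> None:
    """Print the name of any parameter whose gradient contains NaN/Inf."""
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # Bind `name` per-parameter via the default argument.
        def hook(grad: torch.Tensor, name: str = name) -> torch.Tensor:
            if not torch.isfinite(grad).all():
                print(f"non-finite gradient in parameter: {name}")
            return grad

        param.register_hook(hook)
```

Calling this once before training and watching which parameter is reported first at the start of epoch 2 would narrow the NaN to a specific compiled region.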

Note: Forward loss values already diverge significantly between compiled and eager modes, indicating numerical differences in the generated kernels.
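A way to quantify that forward divergence per batch is to run the same input through the eager and the compiled module and compare outputs. This is a sketch with placeholder names (`model`, `batch` stand in for the real ones); in the actual setup the default Inductor backend would be used, and `backend="eager"` here is only so the snippet runs without a compiler toolchain.

```python
import torch

def max_forward_diff(model: torch.nn.Module, batch: torch.Tensor) -> float:
    """Largest absolute elementwise difference between eager and compiled forward."""
    compiled = torch.compile(model, backend="eager")  # use the default backend in the real run
    with torch.no_grad():
        ref = model(batch)
        out = compiled(batch)
    return (ref - out).abs().max().item()
```

If this difference is already large on a single batch with Inductor, the divergence can be bisected per-submodule by compiling one submodule at a time.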

:warning: Apology in Advance

I’m sorry that I cannot provide a minimal reproduction script. This issue occurs in a large-scale internal project with proprietary code that I’m not able to share publicly.

However, I’ve tried my best to document all the symptoms, observations, and patterns I’ve found. I’m hoping someone might have encountered similar issues or could point me in the right direction based on these details.

I truly appreciate any help or insights! :pray: