NCCL Timeout only on H100s, not other hardware

I am running a 2 GPU (same node) training run. The script works:

  1. On two A40s
  2. On one H100

But fails, interestingly, on two H100s. I'm not sure of the source of the hardware dependence. The script successfully runs a pretraining evaluation, a bunch of training steps, and a second evaluation, and then cannot collect gradients when training resumes. I am using the accelerate library. The trace, as well as my scripts and some configs, are below.

accelerate launch pipeline/2.1_self_supervised_training.py
/kfs2/projects/metalsitenn/metal_site_modeling/equiformer/nets/layer_norm.py:89: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  @torch.cuda.amp.autocast(enabled=False)
/kfs2/projects/metalsitenn/metal_site_modeling/equiformer/nets/layer_norm.py:89: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  @torch.cuda.amp.autocast(enabled=False)
x3100c0s5b0n0:3347432:3347432 [0] NCCL INFO Bootstrap : Using hsn0:10.150.3.12<0>
x3100c0s5b0n0:3347432:3347432 [0] NCCL INFO NET/Plugin : dlerror=libnccl-net.so: cannot open shared object file: No such file or directory No plugin found (libnccl-net.so), using internal implementation
x3100c0s5b0n0:3347433:3347433 [1] NCCL INFO cudaDriverVersion 12040
x3100c0s5b0n0:3347433:3347433 [1] NCCL INFO Bootstrap : Using hsn0:10.150.3.12<0>
x3100c0s5b0n0:3347433:3347433 [1] NCCL INFO NET/Plugin : dlerror=libnccl-net.so: cannot open shared object file: No such file or directory No plugin found (libnccl-net.so), using internal implementation
x3100c0s5b0n0:3347432:3347432 [0] NCCL INFO cudaDriverVersion 12040
NCCL version 2.20.5+cuda12.4
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO NET/IB : No device found.
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO NET/IB : No device found.
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO NET/Socket : Using [0]hsn0:10.150.3.12<0> [1]hsn1:10.150.1.122<0> [2]bond0:172.23.1.3<0>
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Using non-device net plugin version 0
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Using network Socket
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO NET/Socket : Using [0]hsn0:10.150.3.12<0> [1]hsn1:10.150.1.122<0> [2]bond0:172.23.1.3<0>
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Using non-device net plugin version 0
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Using network Socket
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO comm 0xaf7b270 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 4000 commId 0x9d8f751b9e10c9be - Init START
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO comm 0xc0884c0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 64000 commId 0x9d8f751b9e10c9be - Init START
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Setting affinity for GPU 1 to 01
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO comm 0xaf7b270 rank 0 nRanks 2 nNodes 1 localRanks 2 localRank 0 MNNVL 0
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 00/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 01/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 02/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 03/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 04/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 05/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 06/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 07/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] -1/-1/-1->0->1 [3] -1/-1/-1->0->1 [4] 1/-1/-1->0->-1 [5] 1/-1/-1->0->-1 [6] -1/-1/-1->0->1 [7] -1/-1/-1->0->1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO P2P Chunksize set to 524288
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO comm 0xc0884c0 rank 1 nRanks 2 nNodes 1 localRanks 2 localRank 1 MNNVL 0
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] -1/-1/-1->1->0 [2] 0/-1/-1->1->-1 [3] 0/-1/-1->1->-1 [4] -1/-1/-1->1->0 [5] -1/-1/-1->1->0 [6] 0/-1/-1->1->-1 [7] 0/-1/-1->1->-1
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO P2P Chunksize set to 524288
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 02/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 03/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 04/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 05/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 06/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 07/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 02/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 03/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 04/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 05/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 06/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 07/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Connected all rings
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Connected all trees
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Connected all rings
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Connected all trees
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO 8 coll channels, 0 collnet channels, 0 nvls channels, 8 p2p channels, 8 p2p channels per peer
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO 8 coll channels, 0 collnet channels, 0 nvls channels, 8 p2p channels, 8 p2p channels per peer
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO comm 0xc0884c0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 64000 commId 0x9d8f751b9e10c9be - Init COMPLETE
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO comm 0xaf7b270 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 4000 commId 0x9d8f751b9e10c9be - Init COMPLETE
/projects/proteinml/.links/miniconda3/envs/metal2/lib/python3.10/site-packages/sklearn/manifold/_t_sne.py:1164: FutureWarning: 'n_iter' was renamed to 'max_iter' in version 1.5 and will be removed in 1.7.
  warnings.warn(
 20%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ                                                                                                                  | 4/20 [00:08<00:41,  2.60s/it]/projects/proteinml/.links/miniconda3/envs/metal2/lib/python3.10/site-packages/dvc_render/vega.py:169: UserWarning: `generate_markdown` can only be used with `LinearTemplate`
  warn("`generate_markdown` can only be used with `LinearTemplate`")  # noqa: B028
 45%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž                                                                              | 9/20 [00:18<00:13,  1.20s/it]/projects/proteinml/.links/miniconda3/envs/metal2/lib/python3.10/site-packages/dvc_render/vega.py:169: UserWarning: `generate_markdown` can only be used with `LinearTemplate`
  warn("`generate_markdown` can only be used with `LinearTemplate`")  # noqa: B028
/projects/proteinml/.links/miniconda3/envs/metal2/lib/python3.10/site-packages/sklearn/manifold/_t_sne.py:1164: FutureWarning: 'n_iter' was renamed to 'max_iter' in version 1.5 and will be removed in 1.7.
  warnings.warn(
 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ                                                                       | 10/20 [00:32<00:51,  5.10s/it]

Eventually…

[Rank 1] Timeout at NCCL work: 456, last enqueued NCCL work: 456, last completed NCCL work: 455.
[rank1]:[E205 14:27:18.083900057 ProcessGroupNCCL.cpp:621] [Rank 1] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank1]:[E205 14:27:18.083904965 ProcessGroupNCCL.cpp:627] [Rank 1] To avoid data inconsistency, we are taking the entire process down.
[rank1]:[E205 14:27:19.451747711 ProcessGroupNCCL.cpp:1515] [PG 0 (default_pg) Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=456, OpType=ALLREDUCE, NumelIn=65297, NumelOut=65297, Timeout(ms)=600000) ran for 600054 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:609 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f64e3b3af86 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x1d2 (0x7f64e4e378d2 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f64e4e3e313 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f64e4e406fc in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xd3b65 (0x7f65325d6b65 in /kfs2/projects/proteinml/.links/miniconda3/envs/metal/bin/../lib/libstdc++.so.6)
frame #5: <unknown function> + 0x81ca (0x7f653423f1ca in /lib64/libpthread.so.0)
frame #6: clone + 0x43 (0x7f6533721e73 in /lib64/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG 0 (default_pg) Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=456, OpType=ALLREDUCE, NumelIn=65297, NumelOut=65297, Timeout(ms)=600000) ran for 600054 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:609 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f64e3b3af86 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x1d2 (0x7f64e4e378d2 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f64e4e3e313 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f64e4e406fc in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xd3b65 (0x7f65325d6b65 in /kfs2/projects/proteinml/.links/miniconda3/envs/metal/bin/../lib/libstdc++.so.6)
frame #5: <unknown function> + 0x81ca (0x7f653423f1ca in /lib64/libpthread.so.0)
frame #6: clone + 0x43 (0x7f6533721e73 in /lib64/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1521 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f64e3b3af86 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe5aa84 (0x7f64e4ac9a84 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xd3b65 (0x7f65325d6b65 in /kfs2/projects/proteinml/.links/miniconda3/envs/metal/bin/../lib/libstdc++.so.6)
frame #3: <unknown function> + 0x81ca (0x7f653423f1ca in /lib64/libpthread.so.0)
frame #4: clone + 0x43 (0x7f6533721e73 in /lib64/libc.so.6)

W0205 14:27:28.725000 140287485417280 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 1075899 closing signal SIGTERM
E0205 14:27:28.964000 140287485417280 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: -6) local_rank: 1 (pid: 1075900) of binary: /projects/proteinml/.links/miniconda3/envs/metal/bin/python3.10
Traceback (most recent call last):
  File "/projects/proteinml/.links/miniconda3/envs/metal/bin/accelerate", line 10, in <module>
    sys.exit(main())
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1163, in launch_command
    multi_gpu_launcher(args)
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/accelerate/commands/launch.py", line 792, in multi_gpu_launcher
    distrib_run.run(args)
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
========================================================
pipeline/2.1_self_supervised_training.py FAILED
--------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
--------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-02-05_14:27:28
  host      : x3100c0s5b0n0.head.cm.kestrel.hpc.nrel.gov
  rank      : 1 (local_rank: 1)
  exitcode  : -6 (pid: 1075900)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 1075900
========================================================
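
For reference, the NCCL INFO lines above come from NCCL_DEBUG=INFO. A sketch of the debugging environment variables that can be toggled here (these are real NCCL/PyTorch variables, but treat the combination as a suggestion; they must be set before the process group is initialized, e.g. at the very top of the training script):

import os

# Set before torch.distributed / Accelerator creates the process group.
os.environ["NCCL_DEBUG"] = "INFO"                 # verbose NCCL logging (the source of the log above)
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # extra c10d checks, e.g. detecting mismatched collectives
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"      # surface the hanging collective as a Python exception
# os.environ["NCCL_P2P_DISABLE"] = "1"            # optionally force NCCL off P2P/CUMEM to test the transport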

The trainer I wrote and am calling is at the bottom of this post (and also in a comment).

Accelerate config:

compute_environment: LOCAL_MACHINE
debug: true
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

My environment (loading CUDA 12.1 from elsewhere on my cluster) is also at the bottom of this post; I had put it in a comment because I was hitting the max post length.

NCCL tests:

./build/all_reduce_perf -b 8 -e 128M -f 2 -g 2
# nThread 1 nGpus 2 minBytes 8 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 agg iters: 1 validation: 1 graph: 0
#
# Using devices
#  Rank  0 Group  0 Pid 1512849 on x3100c0s5b0n0 device  0 [0x04] NVIDIA H100 80GB HBM3
#  Rank  1 Group  0 Pid 1512849 on x3100c0s5b0n0 device  1 [0x64] NVIDIA H100 80GB HBM3
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
           8             2     float     sum      -1     7.09    0.00    0.00      0     7.33    0.00    0.00      0
          16             4     float     sum      -1     7.40    0.00    0.00      0     7.36    0.00    0.00      0
          32             8     float     sum      -1     7.32    0.00    0.00      0     7.29    0.00    0.00      0
          64            16     float     sum      -1     7.55    0.01    0.01      0     7.32    0.01    0.01      0
         128            32     float     sum      -1     7.40    0.02    0.02      0     7.34    0.02    0.02      0
         256            64     float     sum      -1     7.47    0.03    0.03      0     7.35    0.03    0.03      0
         512           128     float     sum      -1     7.31    0.07    0.07      0     7.26    0.07    0.07      0
        1024           256     float     sum      -1     7.77    0.13    0.13      0     7.56    0.14    0.14      0
        2048           512     float     sum      -1     7.80    0.26    0.26      0     7.69    0.27    0.27      0
        4096          1024     float     sum      -1     8.03    0.51    0.51      0     7.80    0.53    0.53      0
        8192          2048     float     sum      -1     8.36    0.98    0.98      0     8.13    1.01    1.01      0
       16384          4096     float     sum      -1     8.55    1.92    1.92      0     8.26    1.98    1.98      0
       32768          8192     float     sum      -1     8.65    3.79    3.79      0     8.51    3.85    3.85      0
       65536         16384     float     sum      -1     9.02    7.26    7.26      0     8.41    7.79    7.79      0
      131072         32768     float     sum      -1    10.14   12.92   12.92      0     9.77   13.41   13.41      0
      262144         65536     float     sum      -1    12.83   20.43   20.43      0    11.84   22.15   22.15      0
      524288        131072     float     sum      -1    24.62   21.30   21.30      0    25.45   20.60   20.60      0
     1048576        262144     float     sum      -1    28.37   36.96   36.96      0    28.29   37.07   37.07      0
     2097152        524288     float     sum      -1    36.02   58.23   58.23      0    36.02   58.23   58.23      0
     4194304       1048576     float     sum      -1    52.13   80.47   80.47      0    51.97   80.71   80.71      0
     8388608       2097152     float     sum      -1    86.74   96.71   96.71      0    86.53   96.95   96.95      0
    16777216       4194304     float     sum      -1    158.8  105.64  105.64      0    155.5  107.88  107.88      0
    33554432       8388608     float     sum      -1    297.4  112.83  112.83      0    298.1  112.57  112.57      0
    67108864      16777216     float     sum      -1    570.2  117.69  117.69      0    569.1  117.93  117.93      0
   134217728      33554432     float     sum      -1   1109.8  120.94  120.94      0   1107.4  121.20  121.20      0
# Out of bounds values : 0 OK
# A
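
So raw NCCL bandwidth between the two H100s looks healthy. To isolate PyTorch's NCCL path from accelerate and the training code, a minimal all-reduce loop along these lines could be run (a sketch; the script name is made up, and the only size taken from the trace is 65297, the element count in the watchdog message above; launch with torchrun --nproc_per_node=2 allreduce_check.py):

import os
import torch
import torch.distributed as dist

def main():
    # torchrun sets RANK / LOCAL_RANK / WORLD_SIZE; env:// is the default init method
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Mimic the failing collective: the watchdog reported ALLREDUCE with 65297 elements
    x = torch.ones(65297, device=f"cuda:{local_rank}")
    for _ in range(1000):
        dist.all_reduce(x)
        x /= dist.get_world_size()  # keep values stable across iterations
    torch.cuda.synchronize()
    if dist.get_rank() == 0:
        print("all_reduce loop finished OK")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()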

Much appreciation for any guidance; I've been beating my head against this one for a good day and only have so much time allocated to this project.

My env:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
absl-py                   2.1.0                    pypi_0    pypi
accelerate                1.3.0                    pypi_0    pypi
aiohappyeyeballs          2.4.4                    pypi_0    pypi
aiohttp                   3.11.12                  pypi_0    pypi
aiohttp-retry             2.9.1                    pypi_0    pypi
aiosignal                 1.3.2                    pypi_0    pypi
amqp                      5.3.1                    pypi_0    pypi
annotated-types           0.7.0                    pypi_0    pypi
antlr4-python3-runtime    4.9.3                    pypi_0    pypi
appdirs                   1.4.4                    pypi_0    pypi
ase                       3.24.0                   pypi_0    pypi
async-timeout             5.0.1                    pypi_0    pypi
asyncssh                  2.19.0                   pypi_0    pypi
atpublic                  5.1                      pypi_0    pypi
attrs                     25.1.0                   pypi_0    pypi
billiard                  4.2.1                    pypi_0    pypi
biopandas                 0.5.1                    pypi_0    pypi
biopython                 1.85                     pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6  
ca-certificates           2024.12.31           h06a4308_0  
celery                    5.4.0                    pypi_0    pypi
certifi                   2025.1.31                pypi_0    pypi
cffi                      1.17.1                   pypi_0    pypi
charset-normalizer        3.4.1                    pypi_0    pypi
click                     8.1.8                    pypi_0    pypi
click-didyoumean          0.3.1                    pypi_0    pypi
click-plugins             1.1.1                    pypi_0    pypi
click-repl                0.3.0                    pypi_0    pypi
cloudpickle               3.1.1                    pypi_0    pypi
colorama                  0.4.6                    pypi_0    pypi
configobj                 5.0.9                    pypi_0    pypi
contourpy                 1.3.1                    pypi_0    pypi
cryptography              44.0.0                   pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
datasets                  3.2.0                    pypi_0    pypi
dictdiffer                0.9.0                    pypi_0    pypi
dill                      0.3.8                    pypi_0    pypi
diskcache                 5.6.3                    pypi_0    pypi
distro                    1.9.0                    pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
dpath                     2.2.0                    pypi_0    pypi
dulwich                   0.22.7                   pypi_0    pypi
dvc                       3.59.0                   pypi_0    pypi
dvc-data                  3.16.9                   pypi_0    pypi
dvc-http                  2.32.0                   pypi_0    pypi
dvc-objects               5.1.0                    pypi_0    pypi
dvc-render                1.0.2                    pypi_0    pypi
dvc-studio-client         0.21.0                   pypi_0    pypi
dvc-task                  0.40.2                   pypi_0    pypi
dvclive                   3.48.1                   pypi_0    pypi
e3nn                      0.5.5                    pypi_0    pypi
entrypoints               0.4                      pypi_0    pypi
fairchem-core             1.4.0                    pypi_0    pypi
filelock                  3.17.0                   pypi_0    pypi
flatten-dict              0.4.2                    pypi_0    pypi
flufl-lock                8.1.0                    pypi_0    pypi
fonttools                 4.55.8                   pypi_0    pypi
frozenlist                1.5.0                    pypi_0    pypi
fsspec                    2024.9.0                 pypi_0    pypi
funcy                     2.0                      pypi_0    pypi
gitdb                     4.0.12                   pypi_0    pypi
gitpython                 3.1.44                   pypi_0    pypi
grandalf                  0.8                      pypi_0    pypi
grpcio                    1.70.0                   pypi_0    pypi
gto                       1.7.2                    pypi_0    pypi
huggingface-hub           0.28.1                   pypi_0    pypi
hydra-core                1.3.2                    pypi_0    pypi
idna                      3.10                     pypi_0    pypi
iterative-telemetry       0.0.9                    pypi_0    pypi
jinja2                    3.1.5                    pypi_0    pypi
joblib                    1.4.2                    pypi_0    pypi
kiwisolver                1.4.8                    pypi_0    pypi
kombu                     5.4.2                    pypi_0    pypi
latexcodec                3.0.0                    pypi_0    pypi
ld_impl_linux-64          2.40                 h12ee557_0  
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
llvmlite                  0.44.0                   pypi_0    pypi
lmdb                      1.6.2                    pypi_0    pypi
looseversion              1.1.2                    pypi_0    pypi
markdown                  3.7                      pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                3.0.2                    pypi_0    pypi
matplotlib                3.10.0                   pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
metalsitenn               0.1                      pypi_0    pypi
mmtf-python               1.1.3                    pypi_0    pypi
monty                     2025.1.9                 pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
msgpack                   1.1.0                    pypi_0    pypi
multidict                 6.1.0                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
narwhals                  1.25.1                   pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
networkx                  3.4.2                    pypi_0    pypi
numba                     0.61.0                   pypi_0    pypi
numpy                     1.26.4                   pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
nvidia-ml-py              12.570.86                pypi_0    pypi
nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
omegaconf                 2.3.0                    pypi_0    pypi
openssl                   3.0.15               h5eee18b_0  
opt-einsum                3.4.0                    pypi_0    pypi
opt-einsum-fx             0.1.4                    pypi_0    pypi
orjson                    3.10.15                  pypi_0    pypi
packaging                 24.2                     pypi_0    pypi
palettable                3.3.3                    pypi_0    pypi
pandas                    2.2.3                    pypi_0    pypi
pathspec                  0.12.1                   pypi_0    pypi
pillow                    11.1.0                   pypi_0    pypi
pip                       25.0            py310h06a4308_0  
platformdirs              4.3.6                    pypi_0    pypi
plotly                    6.0.0                    pypi_0    pypi
prompt-toolkit            3.0.50                   pypi_0    pypi
propcache                 0.2.1                    pypi_0    pypi
protobuf                  5.29.3                   pypi_0    pypi
psutil                    6.1.1                    pypi_0    pypi
pyarrow                   19.0.0                   pypi_0    pypi
pybtex                    0.24.0                   pypi_0    pypi
pycparser                 2.22                     pypi_0    pypi
pydantic                  2.10.6                   pypi_0    pypi
pydantic-core             2.27.2                   pypi_0    pypi
pydot                     3.0.4                    pypi_0    pypi
pygit2                    1.17.0                   pypi_0    pypi
pygments                  2.19.1                   pypi_0    pypi
pygtrie                   2.5.0                    pypi_0    pypi
pymatgen                  2025.1.24                pypi_0    pypi
pynvml                    12.0.0                   pypi_0    pypi
pyparsing                 3.2.1                    pypi_0    pypi
python                    3.10.16              he870216_1  
python-dateutil           2.9.0.post0              pypi_0    pypi
pytz                      2025.1                   pypi_0    pypi
pyyaml                    6.0.2                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
rich                      13.9.4                   pypi_0    pypi
ruamel-yaml               0.18.10                  pypi_0    pypi
ruamel-yaml-clib          0.2.12                   pypi_0    pypi
safetensors               0.5.2                    pypi_0    pypi
scikit-learn              1.6.1                    pypi_0    pypi
scipy                     1.15.1                   pypi_0    pypi
scmrepo                   3.3.10                   pypi_0    pypi
seaborn                   0.13.2                   pypi_0    pypi
semver                    3.0.4                    pypi_0    pypi
sentry-sdk                2.20.0                   pypi_0    pypi
setproctitle              1.3.4                    pypi_0    pypi
setuptools                75.8.0          py310h06a4308_0  
shellingham               1.5.4                    pypi_0    pypi
shortuuid                 1.0.13                   pypi_0    pypi
shtab                     1.7.1                    pypi_0    pypi
six                       1.17.0                   pypi_0    pypi
smmap                     5.0.2                    pypi_0    pypi
spglib                    2.5.0                    pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0  
sqltrie                   0.11.1                   pypi_0    pypi
submitit                  1.5.2                    pypi_0    pypi
sympy                     1.13.1                   pypi_0    pypi
tabulate                  0.9.0                    pypi_0    pypi
tensorboard               2.18.0                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
threadpoolctl             3.5.0                    pypi_0    pypi
tk                        8.6.14               h39e8969_0  
tokenizers                0.21.0                   pypi_0    pypi
tomlkit                   0.13.2                   pypi_0    pypi
torch                     2.4.0                    pypi_0    pypi
torch-cluster             1.6.3+pt24cu121          pypi_0    pypi
torch-geometric           2.6.1                    pypi_0    pypi
torch-scatter             2.1.2+pt24cu121          pypi_0    pypi
torch-sparse              0.6.18                   pypi_0    pypi
torch-spline-conv         1.2.2+pt24cu121          pypi_0    pypi
torchaudio                2.4.0                    pypi_0    pypi
torchvision               0.19.0                   pypi_0    pypi
tqdm                      4.67.1                   pypi_0    pypi
transformers              4.48.2                   pypi_0    pypi
triton                    3.0.0                    pypi_0    pypi
typer                     0.15.1                   pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
tzdata                    2025.1                   pypi_0    pypi
uncertainties             3.2.2                    pypi_0    pypi
urllib3                   2.3.0                    pypi_0    pypi
vine                      5.1.0                    pypi_0    pypi
voluptuous                0.15.2                   pypi_0    pypi
wandb                     0.19.6                   pypi_0    pypi
wcwidth                   0.2.13                   pypi_0    pypi
werkzeug                  3.1.3                    pypi_0    pypi
wheel                     0.45.1          py310h06a4308_0  
xxhash                    3.5.0                    pypi_0    pypi
xz                        5.4.6                h5eee18b_1  
yarl                      1.18.3                   pypi_0    pypi
zc-lockfile               3.0.post1                pypi_0    pypi
zlib                      1.2.13               h5eee18b_1  

And my trainer:

@dataclass
class EarlyStoppingState:
    """Tracks early stopping state."""
    counter: int = 0
    best_metric: float = float('inf')
    best_step: int = 0
    
    def state_dict(self) -> Dict[str, Any]:
        return asdict(self)
        
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.counter = state_dict['counter']
        self.best_metric = state_dict['best_metric']
        self.best_step = state_dict['best_step']

    def step(self, metric: float, current_step: int, min_improvement: float) -> bool:
        """Returns True if should stop."""
        improvement = (self.best_metric - metric) / self.best_metric
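        # NOTE: on the first call best_metric is inf, so improvement evaluates to nan
        # and the comparison below is False (the first eval always increments the counter).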
        if improvement > min_improvement:
            self.counter = 0
            bad_step = False
        else:
            bad_step = True
            self.counter += 1
            logger.info(f"Early stopping counter triggered: {self.counter}, best metric: {self.best_metric}, current metric: {metric}, improvement: {improvement}, min improvement: {min_improvement}")
        if metric < self.best_metric:
            self.best_metric = metric
            self.best_step = current_step
        return bad_step

@dataclass
class MetalSiteTrainingArgs:
    """Arguments for training."""
    output_dir: str = field(default="./training_output")
    logging_dir: str = field(default="./logs")
    
    # Training loop
    num_epochs: int = field(default=1)
    per_device_train_batch_size: int = field(default=8) 
    per_device_eval_batch_size: int = field(default=8)
    gradient_accumulation_steps: int = field(default=1)
    dataloader_num_workers: int = field(default=0)
    
    # Optimizer
    learning_rate: float = field(default=5e-5)
    weight_decay: float = field(default=0.0)
    gradient_clipping: float = field(default=1.0)
    warmup_pct: float = field(default=0.1)
    frac_noise_loss: float = field(default=0.5)
    
    # Logging and checkpoints
    eval_steps: Optional[int] = field(default=None)
    logging_steps: int = field(default=100) 
    load_best_model_at_end: bool = field(default=True)
    
    # Early stopping
    use_early_stopping: bool = field(default=False)
    early_stopping_patience: int = field(default=3)
    early_stopping_improvement_fraction: float = field(default=0.0)

    def __str__(self):
        return str(asdict(self))

class MetalSiteTrainer:
    """Trainer for metal site models with distributed training support.
    
    Args
    ----
    model: nn.Module
        Model to train
    compute_loss_fn: Callable
        Function to compute loss. Signature should be:
            compute_loss_fn(trainer: MetalSiteTrainer, input_batch: Dict[str, torch.Tensor], return_outputs: bool = False) -> Dict[str, torch.Tensor]
            Must return a dict with at least a 'loss' key.
            During evaluation, this is called with return_outputs=True to return model outputs for metrics.
    args: MetalSiteTrainingArgs
        Training arguments
    train_dataset: Dataset
        Training dataset
    eval_dataset: Dataset
        Evaluation dataset
    data_collator: Callable
        Data collator
    eval_metrics: Optional[Dict[str, Callable]]
        Metrics to compute during evaluation. This is a dict of callables, each with signature f(trainer, outputs, batch), where outputs are the
        returns of compute_loss_fn. If None, only the loss is computed.
    hard_eval_metrics: Optional[Dict[str, Callable]]
        Metrics that require additional computation and are not directly returned by compute_loss_fn. These are called separately with the trainer as the only argument.
        It is up to you to loop through whatever dataset is needed to compute them.
    """
    
    def __init__(
        self,
        model,
        compute_loss_fn: Callable,
        args: MetalSiteTrainingArgs,
        train_dataset=None,
        eval_dataset=None,
        data_collator=None,
        eval_metrics: Optional[Dict[str, Callable]]=None,
        hard_eval_metrics: Optional[Dict[str, Callable]]=None,
        quit_early: bool = False
    ):
        self.args = args
        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.data_collator = data_collator
        self.compute_loss_fn = compute_loss_fn
        self.eval_metrics = eval_metrics or {}
        
        # Initialize early stopping
        self.early_stopping = EarlyStoppingState() if args.use_early_stopping else None
        
        # Initialize accelerator
        ipgk = InitProcessGroupKwargs(timeout=timedelta(180))
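        # NOTE: timedelta's first positional argument is days, so timedelta(180) requests a
        # 180-day timeout (timedelta(seconds=180) would be 3 minutes). The watchdog message
        # above still shows the default Timeout(ms)=600000, so this handler does not appear
        # to be taking effect on the process group.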
        self.accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            log_with="dvclive",
            project_dir=args.output_dir,
            kwargs_handlers=[ipgk]
        )
        if self.accelerator.is_main_process:
            logger.info(f"Accelerator params: {self.accelerator.__dict__}")
        self.accelerator.init_trackers(project_name="training", init_kwargs={
            "dvclive": {
                "dir": os.path.join(args.output_dir, "dvclive"),
                "report": 'md',
                "save_dvc_exp": False,
                "dvcyaml": None
            }
        })
        
        if self.early_stopping:
            self.accelerator.register_for_checkpointing(self.early_stopping)

        # Create dataloaders
        self.train_dataloader = self._get_train_dataloader() if train_dataset else None
        self.eval_dataloader = self._get_eval_dataloader() if eval_dataset else None

        # Set up optimizer and scheduler   
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=args.learning_rate,
            epochs=args.num_epochs,
            steps_per_epoch=len(self.train_dataloader),
            pct_start=args.warmup_pct
        )

        # Prepare everything with accelerator
        prepared = self.accelerator.prepare(
            self.model,
            self.optimizer, 
            self.train_dataloader,
            self.eval_dataloader,
            self.scheduler
        )
        self.model, self.optimizer, self.train_dataloader, self.eval_dataloader, self.scheduler = prepared

        self.n_warmup_steps = args.warmup_pct * args.num_epochs * len(self.train_dataloader)

        # hard eval metrics
        self.hard_eval_metrics = hard_eval_metrics or {}

        # create checkpointing folder if not present
        if self.accelerator.is_main_process:
            if not os.path.exists(os.path.join(args.output_dir, "checkpoints")):
                os.makedirs(os.path.join(args.output_dir, "checkpoints"))
        self.quit_early = quit_early
        os.environ["NCCL_DEBUG"] = "INFO"

    def _get_train_dataloader(self) -> DataLoader:
        """Create training dataloader."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            shuffle=True
        )

    def _get_eval_dataloader(self) -> DataLoader:
        """Create evaluation dataloader."""
        return DataLoader(
            self.eval_dataset,
            batch_size=self.args.per_device_eval_batch_size,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers
        )
    
    def save_checkpoint(self, output_dir: str):
        """Save model checkpoint with dynamic parameter handling"""
        # Initialize dynamic params before saving
        dummy_batch = next(iter(self.train_dataloader))
        with torch.no_grad():
            self.model(**dummy_batch)
        
        self.accelerator.save_state(output_dir, safe_serialization=False)

    def load_checkpoint(self, checkpoint_dir: str):
        """Load checkpoint with dynamic parameter handling"""
        # Initialize dynamic params before loading
        dummy_batch = next(iter(self.train_dataloader))
        with torch.no_grad():
            self.model(**dummy_batch)
            
        self.accelerator.load_state(checkpoint_dir)

    def _cleanup_checkpoints(self):
        """Maintain only best checkpoint and last N checkpoints where N=patience."""
        if not self.early_stopping:
            return
            
        checkpoint_dir = os.path.join(self.args.output_dir, "checkpoints")
        checkpoints = sorted([
            int(f.split('_')[-1]) 
            for f in os.listdir(checkpoint_dir) 
            if f.startswith('step_')
        ])
        
        # Always keep best checkpoint
        checkpoints_to_keep = {self.early_stopping.best_step}
        
        # Keep last patience number of checkpoints
        patience_checkpoints = checkpoints[-self.args.early_stopping_patience:]
        checkpoints_to_keep.update(patience_checkpoints)
        
        # Remove others
        import shutil
        for step in checkpoints:
            if step not in checkpoints_to_keep:
                checkpoint_path = os.path.join(checkpoint_dir, f'step_{step}')
                if os.path.exists(checkpoint_path):
                    shutil.rmtree(checkpoint_path)

    def evaluate(self) -> float:
        """Run evaluation and compute metrics over full dataset."""
        self.model.eval()
        total_loss = 0
        num_batches = 0
        
        # Initialize metric accumulators for each process
        process_metrics = {name: [] for name in self.eval_metrics.keys()}
        
        for batch in self.eval_dataloader:
            with torch.no_grad():
                outputs = self.compute_loss_fn(self, batch, return_outputs=True)
                loss = outputs["loss"]
                total_loss += loss.detach().float()
                
                # Compute metrics on each process separately
                if self.eval_metrics:
                    for name, func in self.eval_metrics.items():
                        metric_val = func(self, outputs, batch)
                        if metric_val is not None:
                            process_metrics[name].append(metric_val)
                            
            num_batches += 1

        # Gather and average loss across processes
        total_loss = self.accelerator.gather(total_loss).mean()
        num_batches = self.accelerator.gather(torch.tensor(num_batches, device=self.accelerator.device, dtype=torch.float)).mean()
        avg_loss = total_loss / num_batches

        # Average metrics for each process then gather
        metrics = {"eval/loss": avg_loss.cpu().item()}
        if self.eval_metrics:
            for name, values in process_metrics.items():
                if values:  # Only process if we have values
                    process_avg = torch.tensor(sum(values) / len(values), device=self.accelerator.device)
                    gathered_avgs = self.accelerator.gather(process_avg)
                    metrics[f"eval/{name}"] = gathered_avgs.mean().cpu().item()
                else:
                    metrics[f"eval/{name}"] = float('nan')
                    
        self.accelerator.log(metrics, step=self.global_step)

        # Run any hard metrics
        for name, func in self.hard_eval_metrics.items():
            func(self)
        
        self.model.train()
        torch.cuda.empty_cache()
        return avg_loss.item()

    def train(self, resume_from_checkpoint: Optional[str] = None):
        # Add global step tracking
        self.global_step = 0
        if resume_from_checkpoint:
            # Recover the global step from the checkpoint directory name (step_<N>)
            self.global_step = int(resume_from_checkpoint.split('_')[-1])
            self.accelerator.load_state(resume_from_checkpoint)
            logger.info(f"Resumed from checkpoint: {resume_from_checkpoint}")

        if self.accelerator.is_main_process:
            logger.info(
                f"Training with {self.accelerator.num_processes} processes on {self.accelerator.device.type}\n"
                f" - output_dir: {self.args.output_dir}\n"
                f" - examples in dataset: {len(self.train_dataset)}\n"
                f" - per device batch size: {self.args.per_device_train_batch_size}\n"
                f" - gradient accumulation steps: {self.args.gradient_accumulation_steps}\n"

                f" - effective batch size: {self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.accelerator.num_processes}\n"
                f" - total epochs: {self.args.num_epochs}\n"
                f" - steps per epoch: {len(self.train_dataloader)}\n"
                f" - total steps: {self.args.num_epochs * len(self.train_dataloader)}\n"
                f" - param updates per epoch: {len(self.train_dataloader) // self.args.gradient_accumulation_steps}\n"
                f" - warmup steps: {self.n_warmup_steps}\n"
                f" - log training loss every {self.args.logging_steps} steps\n"
                f" - eval and checkpoint every {self.args.eval_steps} steps\n"
                f" - total trainable parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}"
            )

        # run eval before training
        if not self.quit_early:
            self.evaluate()
        
        # Training loop
        for epoch in range(self.args.num_epochs):
            self.model.train()
            total_loss = 0
            
            progress_bar = tqdm(
                self.train_dataloader,
                disable=not self.accelerator.is_local_main_process
            )

            for batch in progress_bar:
                with self.accelerator.accumulate(self.model):
                    outputs = self.compute_loss_fn(self, batch)
                    loss = outputs["loss"]
                    
                    self.accelerator.backward(loss)
                    
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.model.parameters(),
                            self.args.gradient_clipping
                        )
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                        self.scheduler.step()

                        if self.quit_early:
                            logger.info("Quitting early")
                            return

                    total_loss += loss.detach().float()

                # Increment global step
                self.global_step += 1

                # Log training metrics
                if self.global_step > 0 and self.global_step % self.args.logging_steps == 0:
                    avg_loss = total_loss / self.args.logging_steps
                    self.accelerator.log({
                        "train/loss": avg_loss.item(),
                        "train/epoch": epoch,
                        "train/global_step": self.global_step,
                        "train/learning_rate": self.optimizer.param_groups[0]["lr"]
                    }, step=self.global_step)
                    total_loss = 0

                # Evaluate and checkpoint if needed
                if (
                    self.args.eval_steps 
                    and self.global_step > 0 
                    and self.global_step % self.args.eval_steps == 0
                ):
                    eval_loss = self.evaluate()
                    self.model.train()

                    # Save checkpoint
                    if self.accelerator.is_main_process:
                        output_dir = os.path.join(
                            self.args.output_dir,
                            "checkpoints",
                            f"step_{self.global_step}"
                        )
                        self.save_checkpoint(output_dir)
                        self._cleanup_checkpoints()

                        if self.early_stopping:
                            should_stop = self.early_stopping.step(
                                eval_loss,
                                self.global_step,
                                self.args.early_stopping_improvement_fraction
                            )
                            if (should_stop and 
                                self.early_stopping.counter >= self.args.early_stopping_patience):
                                if self.global_step > self.n_warmup_steps:
                                    logger.info("Early stopping triggered")
                                    self._finish_up()
                                    return
                        
        # Finish up
        if self.accelerator.is_main_process:
            self._finish_up()


    def _finish_up(self):
        output_dir = os.path.join(
            self.args.output_dir,
            "checkpoints",
            f"step_{self.global_step}"
        )
        self.save_checkpoint(output_dir)

        if self.args.load_best_model_at_end and self.early_stopping and self.early_stopping.best_step > 0:
            best_model_path = os.path.join(
                self.args.output_dir,
                "checkpoints",
                f"step_{self.early_stopping.best_step}"
            )
            logger.info(f"Loading best model from step {self.early_stopping.best_step}")
            self.load_checkpoint(best_model_path)
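
For completeness, the compute_loss_fn I pass in has roughly this shape (a simplified, illustrative sketch with made-up internals, not my actual loss):

def compute_loss_fn(trainer, input_batch, return_outputs=False):
    # Forward pass on the prepared (DDP-wrapped) model
    outputs = trainer.model(**input_batch)
    loss = outputs["loss"] if isinstance(outputs, dict) else outputs.loss
    result = {"loss": loss}
    if return_outputs:
        # evaluate() passes the returned dict to each eval_metric callable
        result["outputs"] = outputs
    return result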

Are you seeing this issue with a newer PyTorch release?
2.4.0+cu121 is quite old and we don’t build and test PyTorch with CUDA 12.1 anymore.

Thanks for the time Piotr, you’re helping a lot of people on these forums - much appreciation.

One of my model dependencies is torch_geometric, which seems to advertise support only up to cu121, but I will try a newer build and report back.
