DDP num_workers>0 CUDA error

Hello,

I installed PyTorch on my local machine with pip, following the instructions on the PyTorch "Start Locally" page, and ran DDP training with PyTorch Lightning on 4 RTX 2080 Ti GPUs without NVLink. After several epochs, DDP training stalled and threw this error:

terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f39b6779617 in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f39b673498d in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f39b68349f8 in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10::cuda::ExchangeDevice(int) + 0x8a (0x7f39b6834e5a in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0xe23e12 (0x7f39b768ae12 in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0x510c06 (0x7f39f9569c06 in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x55ca7 (0x7f39b675eca7 in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #7: c10::TensorImpl::~TensorImpl() + 0x1e3 (0x7f39b6756cb3 in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #8: c10::TensorImpl::~TensorImpl() + 0x9 (0x7f39b6756e49 in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #9: <unknown function> + 0x7c0b88 (0x7f39f9819b88 in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #10: THPVariable_subclass_dealloc(_object*) + 0x305 (0x7f39f9819f15 in /home/research/.pyenv/versions/ssl/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>

The error is intermittent, so it was hard to pin down its cause. What I did find is that it disappears completely when I set num_workers=0 for the dataloaders or train on a single GPU. Neither is a viable option, though, because I have a lot of images to train on. Since I couldn't find a better solution, I tried running the program in a Docker image, which got rid of the error completely, even with num_workers>0.
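
One thing that may be worth trying before giving up on num_workers>0: by default on Linux (with the Python versions in this thread), DataLoader workers are started with `fork`, and the PyTorch multiprocessing notes state that the CUDA runtime does not support the `fork` start method. If anything reachable from the dataset touches CUDA in a worker (the `c10::TensorImpl::~TensorImpl()` frames above may indicate a CUDA tensor being freed there), the worker can die with exactly this `initialization error`. The sketch below starts workers with `spawn` through the `multiprocessing_context` argument of `DataLoader`. `RandomImages` and `make_loader` are hypothetical names, and this is an assumption to test, not a confirmed fix for this setup.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomImages(Dataset):
    """Hypothetical stand-in dataset that only ever returns CPU tensors."""

    def __init__(self, n: int = 64):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int):
        # Keep everything on the CPU here; move batches to the GPU in the training step.
        image = torch.rand(3, 32, 32)
        label = idx % 10
        return image, label


def make_loader(dataset: Dataset, num_workers: int = 4, batch_size: int = 16) -> DataLoader:
    use_workers = num_workers > 0
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # "spawn" starts each worker in a fresh interpreter instead of forking the
        # parent, so workers never inherit an already-initialized CUDA context.
        # DataLoader rejects multiprocessing_context and persistent_workers when num_workers == 0.
        multiprocessing_context="spawn" if use_workers else None,
        # Spawning workers is slow, so keep them alive between epochs.
        persistent_workers=use_workers,
        pin_memory=torch.cuda.is_available(),
    )


if __name__ == "__main__":
    loader = make_loader(RandomImages(), num_workers=4)
    for images, labels in loader:
        pass  # the training step would go here
```

With `spawn`, every worker re-imports the main script and unpickles the dataset, so the `if __name__ == "__main__":` guard is required and the dataset must be picklable.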

The main difference I found is that I installed PyTorch with pip in the local environment, while the Docker image uses conda, but I don't think that's the culprit. I suspect a mismatch between the CUDA runtime version and the CUDA version torch was compiled with, but I believe that would also cause CUDA errors on a single GPU, which was not the case. Could anyone help me find the exact reason for this weird behavior?
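
To compare the two environments from inside Python, a sketch like the one below reports the versions the installed torch build actually uses. As far as I know, the `CUDA runtime version: 12.3.52` line in the local `collect_env` output reflects the system CUDA toolkit, while the pip `cu121` build loads its own bundled CUDA 12.1 libraries, so that line alone may not indicate a mismatch. `cuda_build_report` is a hypothetical helper name.

```python
import platform

import torch


def cuda_build_report() -> dict:
    """Collect the versions that the running torch build actually uses."""
    report = {
        "python": platform.python_version(),
        "torch": torch.__version__,
        # CUDA version torch was compiled against (None for CPU-only builds).
        "torch_cuda": torch.version.cuda,
        # cuDNN version torch loads at runtime (None if cuDNN is unavailable).
        "cudnn": torch.backends.cudnn.version(),
        "cuda_available": torch.cuda.is_available(),
        "devices": [],
    }
    if report["cuda_available"]:
        # Querying device names may initialize CUDA in this process, so run this
        # as a standalone script rather than inside the training process.
        for i in range(torch.cuda.device_count()):
            report["devices"].append(torch.cuda.get_device_name(i))
    return report


if __name__ == "__main__":
    for key, value in cuda_build_report().items():
        print(f"{key}: {value}")
```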

Below is the output of collect_env on the local machine and in the container.

# local
PyTorch version: 2.1.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.10.13 (main, Oct 27 2023, 15:41:02) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-87-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.3.52
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 2080 Ti
GPU 1: NVIDIA GeForce RTX 2080 Ti
GPU 2: NVIDIA GeForce RTX 2080 Ti
GPU 3: NVIDIA GeForce RTX 2080 Ti

Nvidia driver version: 545.23.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             40
On-line CPU(s) list:                0-39
Thread(s) per core:                 2
Core(s) per socket:                 10
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              79
Model name:                         Intel(R) Xeon(R) CPU E5-2640 v4 @ 2.40GHz
Stepping:                           1
CPU MHz:                            1200.000
CPU max MHz:                        3400.0000
CPU min MHz:                        1200.0000
BogoMIPS:                           4789.49
Virtualization:                     VT-x
L1d cache:                          640 KiB
L1i cache:                          640 KiB
L2 cache:                           5 MiB
L3 cache:                           50 MiB
NUMA node0 CPU(s):                  0-9,20-29
NUMA node1 CPU(s):                  10-19,30-39
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        KVM: Mitigation: VMX disabled
Vulnerability L1tf:                 Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT vulnerable
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts md_clear flush_l1d

Versions of relevant libraries:
[pip3] numpy==1.26.1
[pip3] pytorch-lightning==2.1.0
[pip3] torch==2.1.0
[pip3] torch-tb-profiler==0.4.3
[pip3] torchmetrics==1.2.0
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0
[conda] Could not collect

# container
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-87-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 2080 Ti
GPU 1: NVIDIA GeForce RTX 2080 Ti
GPU 2: NVIDIA GeForce RTX 2080 Ti
GPU 3: NVIDIA GeForce RTX 2080 Ti

Nvidia driver version: 545.23.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             40
On-line CPU(s) list:                0-39
Thread(s) per core:                 2
Core(s) per socket:                 10
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              79
Model name:                         Intel(R) Xeon(R) CPU E5-2640 v4 @ 2.40GHz
Stepping:                           1
CPU MHz:                            1200.000
CPU max MHz:                        3400.0000
CPU min MHz:                        1200.0000
BogoMIPS:                           4789.49
Virtualization:                     VT-x
L1d cache:                          640 KiB
L1i cache:                          640 KiB
L2 cache:                           5 MiB
L3 cache:                           50 MiB
NUMA node0 CPU(s):                  0-9,20-29
NUMA node1 CPU(s):                  10-19,30-39
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        KVM: Mitigation: VMX disabled
Vulnerability L1tf:                 Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT vulnerable
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts md_clear flush_l1d

Versions of relevant libraries:
[pip3] numpy==1.26.0
[pip3] pytorch-lightning==2.1.0
[pip3] torch==2.1.0
[pip3] torchaudio==2.1.0
[pip3] torchelastic==0.2.2
[pip3] torchmetrics==1.2.0
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46343  
[conda] mkl-service               2.4.0           py310h5eee18b_1  
[conda] mkl_fft                   1.3.8           py310h5eee18b_0  
[conda] mkl_random                1.2.4           py310hdb19cb5_0  
[conda] numpy                     1.26.0          py310h5f9d8c6_0  
[conda] numpy-base                1.26.0          py310hb5e798b_0  
[conda] pytorch                   2.1.0           py3.10_cuda12.1_cudnn8.9.2_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-lightning         2.1.0                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.1.0               py310_cu121    pytorch
[conda] torchelastic              0.2.2                    pypi_0    pypi
[conda] torchmetrics              1.2.0                    pypi_0    pypi
[conda] torchtriton               2.1.0                     py310    pytorch
[conda] torchvision               0.16.0              py310_cu121    pytorch

The interaction between CUDA and dataloader workers is known to be tricky (tricky enough for @colesbury to work on NoGIL Python, apparently). My standard solution is to move as much preprocessing as I can to the GPU (and to pre-scale the images ahead of time where possible) so that I don't need several workers in the dataloader.
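
A minimal sketch of the "preprocess on the GPU" approach, assuming the images are stored pre-scaled as `uint8` so the dataloader only hands over raw bytes: the float conversion, resize, and normalization then happen on the device in the training step. `gpu_preprocess`, the target size, and the ImageNet-style mean/std values are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Illustrative per-channel statistics (the common ImageNet values); use your dataset's own.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def gpu_preprocess(batch_uint8: torch.Tensor, device: torch.device, size: int = 224) -> torch.Tensor:
    """Turn an (N, 3, H, W) uint8 CPU batch into a normalized float batch on `device`."""
    # non_blocking only overlaps the copy with compute when the source is pinned memory.
    x = batch_uint8.to(device, non_blocking=True)
    x = x.float().div_(255.0)
    # Resize on the device instead of in the dataloader workers.
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    return (x - MEAN.to(device)) / STD.to(device)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fake_batch = torch.randint(0, 256, (8, 3, 256, 256), dtype=torch.uint8)
    out = gpu_preprocess(fake_batch, device)
    print(out.shape, out.dtype, out.device)
```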

Best regards

Thomas

I am also experiencing a similar issue. For me, the error is not intermittent: my training pipeline reliably crashes after about 1 hour of training (1 complete training epoch and 1 unfinished validation epoch). As in the original post, turning off multiprocess data loading fixes the issue. I am running DDP on 8 A100 GPUs spread across 4 hosts, 2 GPUs each. The error occurs when I set the number of dataloader workers to 4. PyTorch and the other libraries were installed with conda. My environment and stack traces are pasted below.

To clarify, the error I get in the DDP model process says that one of the dataloader workers was aborted. There are no errors in the Python logs of the dataloader process, only output on stderr.
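
Since the aborted worker leaves nothing in the Python logs, one option is to enable the standard-library `faulthandler` in each worker through the `worker_init_fn` argument of `DataLoader`, writing to a per-worker file. On `SIGABRT` it dumps the Python stack of every thread in that process, which should show what the worker was doing when the C++ exception was thrown. (The `Fatal Python error: Aborted` line at the end of the worker trace below looks like `faulthandler` output already, so part of this may be sitting in stderr, interleaved with other ranks.) `make_worker_init_fn` and the file naming are hypothetical choices; `functools.partial` is used instead of a closure so the function stays picklable if workers are started with `spawn`.

```python
import faulthandler
import functools
import os

# faulthandler writes to the file descriptor directly, so the file objects must stay open.
_FAULT_LOGS = []


def _enable_faulthandler(worker_id: int, log_dir: str) -> None:
    os.makedirs(log_dir, exist_ok=True)
    # Include the pid: worker_id restarts at 0 in every DDP rank and every loader.
    path = os.path.join(log_dir, f"dataloader-worker-{worker_id}-pid{os.getpid()}.log")
    log_file = open(path, "w")
    _FAULT_LOGS.append(log_file)
    # Dump the Python stacks of all threads on SIGSEGV, SIGFPE, SIGABRT, SIGBUS and SIGILL.
    faulthandler.enable(file=log_file, all_threads=True)


def make_worker_init_fn(log_dir: str):
    """Return a picklable worker_init_fn for torch.utils.data.DataLoader."""
    return functools.partial(_enable_faulthandler, log_dir=log_dir)
```

Usage would look like `DataLoader(dataset, num_workers=4, worker_init_fn=make_worker_init_fn("fault_logs"))`.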

A few questions:

  1. The dataloader stack trace goes through some pickling code. My understanding from the docs is that on Linux the default mechanism for sharing data between the dataloader workers and the model is /dev/shm, not pickling. Where is the best place to start looking to figure out where the pickling is coming from?
  2. I am running on an on-premise SLURM cluster, using Lightning to manage the training loop. CUDA 12.2 is installed on the cluster. AFAIK, there is no PyTorch build linked against CUDA 12.2, only 11.8, 12.1 and 12.4. I have tried 11.8 and 12.1 and hit this problem with both. Would it make sense to try 12.4 as well? Or should I be experimenting with something fancier, like Docker/Singularity containers?
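
On question 1, as far as I understand it, /dev/shm only holds the tensor storage: the sample itself still travels to the main process through a `multiprocessing` queue, which pickles it in a feeder thread (that would fit the `_pickle` and `clone` frames at the bottom of the worker trace below), with PyTorch's registered reductions swapping each storage for a shared-memory handle. So some pickling is expected, and the more useful question may be whether anything reachable from a sample, or from the dataset object, lives on the GPU. A hedged debugging sketch: a hypothetical `find_cuda_tensors` helper that walks a sample and reports the path of any object whose `is_cuda` attribute is true, as it is for a `torch.Tensor` on a GPU.

```python
from typing import Any, List, Optional, Set


def find_cuda_tensors(obj: Any, path: str = "sample", _seen: Optional[Set[int]] = None) -> List[str]:
    """Return the paths of all objects under `obj` that report `is_cuda == True`."""
    if _seen is None:
        _seen = set()
    if id(obj) in _seen:
        return []
    _seen.add(id(obj))

    # torch.Tensor exposes .is_cuda; checking the attribute avoids importing torch here.
    if getattr(obj, "is_cuda", False) is True:
        return [path]

    found: List[str] = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            found += find_cuda_tensors(value, f"{path}[{key!r}]", _seen)
    elif isinstance(obj, (list, tuple, set)):
        for i, value in enumerate(obj):
            found += find_cuda_tensors(value, f"{path}[{i}]", _seen)
    return found
```

For example, `find_cuda_tensors(dataset[0])` and `find_cuda_tensors(vars(dataset), "dataset")` in the main process should both return an empty list if the dataset is CPU-only.
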
Here is the model stack trace:
ERROR 03/Aug/2024 05:12:45.018 [2976263:MainThread] <pp_ai_cli> - Top level catch
Traceback (most recent call last):
  File "/hps/nobackup/arl/chembl/evgeny/pp_jobs/pp_14/contrast_1/./code/scripts/ai/pp_ai_cli.py", line 25, in cli_main
    cli = ProteinProductionCli("pp_ai_cli")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/nobackup/arl/chembl/evgeny/pp_jobs/pp_14/contrast_1/code/python/ai/cli.py", line 192, in __init__
    super().__init__(*args, **{**kwargs, 'save_config_kwargs': {"overwrite": True}})
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 394, in __init__
    self._run_subcommand(self.subcommand)
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 701, in _run_subcommand
    fn(**fn_kwargs)
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
    self.fit_loop.run()
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 141, in run
    self.on_advance_end(data_fetcher)
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 295, in on_advance_end
    self.val_loop.run()
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 410, in validation_step
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 640, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1636, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1454, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 633, in wrapped_forward
    out = method(*_args, **_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/nobackup/arl/chembl/evgeny/pp_jobs/pp_14/contrast_1/code/python/ai/model.py", line 89, in validation_step
    self.metrics.update_train(logits, batch)
  File "/hps/nobackup/arl/chembl/evgeny/pp_jobs/pp_14/contrast_1/code/python/ai/metrics.py", line 198, in update_train
    self.update_metric(m, logits, ys)
  File "/hps/nobackup/arl/chembl/evgeny/pp_jobs/pp_14/contrast_1/code/python/ai/metrics.py", line 187, in update_metric
    cls.update_multiclass(m, logits, ys)
  File "/hps/nobackup/arl/chembl/evgeny/pp_jobs/pp_14/contrast_1/code/python/ai/metrics.py", line 176, in update_multiclass
    m.update(probs, ys)
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torchmetrics/metric.py", line 492, in wrapped_func
    raise err
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torchmetrics/metric.py", line 482, in wrapped_func
    update(*args, **kwargs)
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torchmetrics/classification/precision_recall_curve.py", line 364, in update
    _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index)
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torchmetrics/functional/classification/precision_recall_curve.py", line 420, in _multiclass_precision_recall_curve_tensor_validation
    num_unique_values = len(torch.unique(target))
                            ^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/_jit_internal.py", line 503, in fn
    return if_false(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/_jit_internal.py", line 503, in fn
    return if_false(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/functional.py", line 997, in _return_output
    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/functional.py", line 911, in _unique_impl
    output, inverse_indices, counts = torch._unique2(
                                      ^^^^^^^^^^^^^^^
  File "/hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/utils/data/_utils/signal_handling.py", line 67, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 2982511) is killed by signal: Aborted. 

Here is the stderr stack trace, which presumably comes from the dataloader worker:
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /opt/conda/conda-bld/pytorch_1720538437738/work/c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fcd29de7f86 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fcd29d96d10 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fcd29ec3ee8 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x1d806 (0x7fcd29e8e806 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x1f763 (0x7fcd29e90763 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x1faa2 (0x7fcd29e90aa2 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
frame #6: <unknown function> + 0x5de610 (0x7fcd8c9cc610 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libtorch_python.so)
frame #7: <unknown function> + 0x6abdf (0x7fcd29dcbbdf in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #8: c10::TensorImpl::~TensorImpl() + 0x21b (0x7fcd29dc4c3b in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #9: c10::TensorImpl::~TensorImpl() + 0x9 (0x7fcd29dc4de9 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #10: <unknown function> + 0x89a8b8 (0x7fcd8cc888b8 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libtorch_python.so)
frame #11: THPVariable_subclass_dealloc(_object*) + 0x2c6 (0x7fcd8cc88c06 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/site-packages/torch/lib/libtorch_python.so)
frame #12: <unknown function> + 0x1ded92 (0x55e48cc88d92 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #13: <unknown function> + 0x238497 (0x55e48cce2497 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #14: <unknown function> + 0x1ded92 (0x55e48cc88d92 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #15: <unknown function> + 0x3003e8 (0x55e48cdaa3e8 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #16: <unknown function> + 0x2196df (0x55e48ccc36df in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #17: <unknown function> + 0x2196a1 (0x55e48ccc36a1 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #18: <unknown function> + 0x2196a1 (0x55e48ccc36a1 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #19: <unknown function> + 0x2196a1 (0x55e48ccc36a1 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #20: <unknown function> + 0x2196a1 (0x55e48ccc36a1 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #21: <unknown function> + 0x2196a1 (0x55e48ccc36a1 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #22: <unknown function> + 0x2196a1 (0x55e48ccc36a1 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #23: <unknown function> + 0x2180d7 (0x55e48ccc20d7 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #24: <unknown function> + 0x1d7141 (0x55e48cc81141 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #25: <unknown function> + 0x1d4ba2 (0x55e48cc7eba2 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #26: <unknown function> + 0x2a00eb (0x55e48cd4a0eb in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #27: PyFunction_NewWithQualName + 0x3c4 (0x55e48cca44b4 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #28: _PyEval_EvalFrameDefault + 0x3852 (0x55e48cc9af72 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #29: _PyFunction_Vectorcall + 0x181 (0x55e48ccbb4c1 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #30: <unknown function> + 0x21886c (0x55e48ccc286c in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #31: _PyObject_MakeTpCall + 0x233 (0x55e48cc8a303 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #32: _PyEval_EvalFrameDefault + 0x716 (0x55e48cc97e36 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #33: _PyFunction_Vectorcall + 0x181 (0x55e48ccbb4c1 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #34: PyObject_CallOneArg + 0x52 (0x55e48ccc0602 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #35: <unknown function> + 0xb05e (0x7fccce15d05e in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #36: <unknown function> + 0xa6a8 (0x7fccce15c6a8 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #37: <unknown function> + 0x9356 (0x7fccce15b356 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #38: <unknown function> + 0xb286 (0x7fccce15d286 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #39: <unknown function> + 0x9641 (0x7fccce15b641 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #40: <unknown function> + 0x9356 (0x7fccce15b356 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #41: <unknown function> + 0xb286 (0x7fccce15d286 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #42: <unknown function> + 0x9641 (0x7fccce15b641 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #43: <unknown function> + 0xa239 (0x7fccce15c239 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #44: <unknown function> + 0xa239 (0x7fccce15c239 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #45: <unknown function> + 0xb42d (0x7fccce15d42d in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #46: <unknown function> + 0x9641 (0x7fccce15b641 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #47: <unknown function> + 0x9356 (0x7fccce15b356 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #48: <unknown function> + 0x11e4a (0x7fccce163e4a in /hps/software/users/chembl/evgeny/micromamba/envs/ai/lib/python3.11/lib-dynload/_pickle.cpython-311-x86_64-linux-gnu.so)
frame #49: <unknown function> + 0x219899 (0x55e48ccc3899 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #50: PyObject_Vectorcall + 0x2c (0x55e48cca49cc in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #51: _PyEval_EvalFrameDefault + 0x716 (0x55e48cc97e36 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #52: _PyFunction_Vectorcall + 0x181 (0x55e48ccbb4c1 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #53: _PyEval_EvalFrameDefault + 0x49f9 (0x55e48cc9c119 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #54: <unknown function> + 0x2303a4 (0x55e48ccda3a4 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #55: <unknown function> + 0x22fbe0 (0x55e48ccd9be0 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #56: <unknown function> + 0x302c7f (0x55e48cdacc7f in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #57: <unknown function> + 0x2cfde4 (0x55e48cd79de4 in /hps/software/users/chembl/evgeny/micromamba/envs/ai/bin/python)
frame #58: <unknown function> + 0x817a (0x7fcd96e1317a in /lib64/libpthread.so.0)
frame #59: clone + 0x43 (0x7fcd963b8df3 in /lib64/libc.so.6)

Fatal Python error: Aborted

Here is the output of `collect_env`
$ python -m torch.utils.collect_env
<frozen runpy>:128: RuntimeWarning: 'torch.utils.collect_env' found in sys.modules after import of package 'torch.utils', but prior to execution of 'torch.utils.collect_env'; this may result in unpredictable behaviour
Collecting environment information...
PyTorch version: 2.4.0
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Rocky Linux 8.5 (Green Obsidian) (x86_64)
GCC version: (GCC) 8.5.0 20210514 (Red Hat 8.5.0-4)
Clang version: 12.0.1 (Red Hat 12.0.1-4.module+el8.5.0+715+58f51d49)
CMake version: version 3.20.2
Libc version: glibc-2.28

Python version: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-4.18.0-348.20.1.el8_5.x86_64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 535.54.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              48
On-line CPU(s) list: 0-47
Thread(s) per core:  1
Core(s) per socket:  24
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               106
Model name:          Intel(R) Xeon(R) Gold 6342 CPU @ 2.80GHz
Stepping:            6
CPU MHz:             3391.383
CPU max MHz:         3500.0000
CPU min MHz:         800.0000
BogoMIPS:            5600.00
Virtualization:      VT-x
L1d cache:           48K
L1i cache:           32K
L2 cache:            1280K
L3 cache:            36864K
NUMA node0 CPU(s):   0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46
NUMA node1 CPU(s):   1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust sgx bmi1 hle avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid sgx_lc fsrm md_clear pconfig flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.3.3
[pip3] torch==2.4.0
[pip3] torchmetrics==1.4.0.post0
[pip3] triton==3.0.0
[conda] Could not collect

Output of `conda list`
$ conda list | grep torch
  pytorch                        2.4.0            py3.11_cuda12.1_cudnn9.1.0_0  pytorch    
  pytorch-cuda                   12.1             ha16c6d3_5                    pytorch    
  pytorch-lightning              2.3.3            pyhd8ed1ab_0                  conda-forge
  pytorch-mutex                  1.0              cuda                          pytorch    
  torchmetrics                   1.4.0.post0      pyhd8ed1ab_0                  conda-forge
  torchtriton                    3.0.0            py311                         pytorch
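A frequently suggested workaround for "CUDA error: initialization error" with `num_workers>0` is sketched below. It assumes, without confirmation from the trace above, that forked DataLoader workers are inheriting CUDA state from the parent process. It does two things. First, it keeps every dataset sample on the CPU. Second, it starts the workers with the `spawn` start method, so each worker begins in a fresh interpreter rather than as a `fork()` of the parent. `multiprocessing_context` and `persistent_workers` are real `torch.utils.data.DataLoader` arguments. `CPUOnlyDataset`, `make_loader` and the batch size are hypothetical names and values chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CPUOnlyDataset(Dataset):
    """Hypothetical dataset: every sample stays on the CPU, so worker
    processes never create or free CUDA tensors. Move batches to the GPU
    in the training step instead (Lightning does this automatically)."""

    def __init__(self, n: int = 8):
        self.data = torch.arange(n, dtype=torch.float32)  # CPU tensor

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]


def make_loader(dataset: Dataset, num_workers: int = 4) -> DataLoader:
    # "spawn" launches each worker as a fresh interpreter instead of
    # fork()ing the parent, so workers do not inherit its CUDA context.
    # persistent_workers=True keeps the workers alive between epochs,
    # which avoids paying the (slower) spawn start-up cost every epoch.
    return DataLoader(
        dataset,
        batch_size=2,  # arbitrary value for the sketch
        num_workers=num_workers,
        multiprocessing_context="spawn",
        persistent_workers=True,
    )
```

With `spawn`, the script that builds the loader must guard its entry point with `if __name__ == "__main__":`, and the dataset class must be importable at module level, because each worker re-imports the main module. In a LightningModule, `make_loader(...)` would be returned from `train_dataloader()`.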