torch.distributed fails on cluster

Hello

I would like to run torch.distributed on an HPC cluster. The command I’m using is the following:

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node 2 train.py

I’m using two NVIDIA Quadro RTX 6000 GPUs with 24 GB of memory. train.py is a Python script and uses Huggingface Trainer to fine-tune a transformer model.

I’m getting the error shown below. Does somebody know how this can be solved?

/cluster/apps/nss/gcc-8.2.0/python/3.10.4/x86_64/lib64/python3.10/site-packages/torch/distributed/launch.py:178: FutureWarning: The module torch.distributed.launch is deprecated
    and will be removed in future. Use torchrun.
    Note that --use_env is set by default in torchrun.
    If your script expects `--local_rank` argument to be set, please
    change it to read from `os.environ['LOCAL_RANK']` instead. See 
    https://pytorch.org/docs/stable/distributed.html#launch-utility for 
    further instructions
    
      warnings.warn(
    Traceback (most recent call last):
      File "/cluster/home/username/chatbot/gpt_j/train.py", line 294, in <module>
        main(sys.argv[1:])
      File "/cluster/home/username/chatbot/gpt_j/train.py", line 64, in main
        model = Model()
      File "/cluster/home/username/chatbot/gpt_j/model.py", line 43, in __init__
        self.model.to(self.device)
      File "/cluster/apps/nss/gcc-8.2.0/python/3.10.4/x86_64/lib64/python3.10/site-packages/torch/nn/modules/module.py", line 907, in to
        return self._apply(convert)
      File "/cluster/apps/nss/gcc-8.2.0/python/3.10.4/x86_64/lib64/python3.10/site-packages/torch/nn/modules/module.py", line 578, in _apply
        module._apply(fn)
      File "/cluster/apps/nss/gcc-8.2.0/python/3.10.4/x86_64/lib64/python3.10/site-packages/torch/nn/modules/module.py", line 578, in _apply
        module._apply(fn)
      File "/cluster/apps/nss/gcc-8.2.0/python/3.10.4/x86_64/lib64/python3.10/site-packages/torch/nn/modules/module.py", line 578, in _apply
        module._apply(fn)
      [Previous line repeated 1 more time]
      File "/cluster/apps/nss/gcc-8.2.0/python/3.10.4/x86_64/lib64/python3.10/site-packages/torch/nn/modules/module.py", line 601, in _apply
        param_applied = fn(param)
      File "/cluster/apps/nss/gcc-8.2.0/python/3.10.4/x86_64/lib64/python3.10/site-packages/torch/nn/modules/module.py", line 905, in convert
        return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
    RuntimeError: CUDA error: all CUDA-capable devices are busy or unavailable
    CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
    For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
    2022-08-20 02:34:24,834 WARNING:Using custom data configuration default-990e072ab094d8c6
    Fatal error condition occurred in /opt/vcpkg/buildtrees/aws-c-io/src/9e6648842a-364b708815.clean/source/event_loop.c:72: aws_thread_launch(&cleanup_thread, s_event_loop_destroy_async_thread_fn, el_group, &thread_options) == AWS_OP_SUCCESS
    Exiting Application
    ################################################################################
    Stack trace:
    ################################################################################
    /cluster/home/username/.local/lib/python3.10/site-packages/pyarrow/libarrow.so.900(+0x200af06) [0x2ae499958f06]
    /cluster/home/username/.local/lib/python3.10/site-packages/pyarrow/libarrow.so.900(+0x20028e5) [0x2ae4999508e5]
    /cluster/home/username/.local/lib/python3.10/site-packages/pyarrow/libarrow.so.900(+0x1f27e09) [0x2ae499875e09]
    /cluster/home/username/.local/lib/python3.10/site-packages/pyarrow/libarrow.so.900(+0x200ba3d) [0x2ae499959a3d]
    /cluster/home/username/.local/lib/python3.10/site-packages/pyarrow/libarrow.so.900(+0x1f25948) [0x2ae499873948]
    /cluster/home/username/.local/lib/python3.10/site-packages/pyarrow/libarrow.so.900(+0x200ba3d) [0x2ae499959a3d]
    /cluster/home/username/.local/lib/python3.10/site-packages/pyarrow/libarrow.so.900(+0x1ee0b46) [0x2ae49982eb46]
    /cluster/home/username/.local/lib/python3.10/site-packages/pyarrow/libarrow.so.900(+0x194546a) [0x2ae49929346a]
    /lib64/libc.so.6(+0x39ce9) [0x2ae49001fce9]
    /lib64/libc.so.6(+0x39d37) [0x2ae49001fd37]
    /lib64/libc.so.6(__libc_start_main+0xfc) [0x2ae49000855c]
    /cluster/apps/nss/gcc-8.2.0/python/3.10.4/x86_64/bin/python() [0x4006fe]
    2022-08-20 02:34:24,947 WARNING:Reusing dataset text (/cluster/home/username/.cache/huggingface/datasets/text/default-990e072ab094d8c6/0.0.0/21a506d1b2b34316b1e82d0bd79066905d846e5d7e619823c0dd338d6f1fa6ad)
    
      0%|          | 0/1 [00:00<?, ?it/s]
    100%|██████████| 1/1 [00:00<00:00,  3.57it/s]
    100%|██████████| 1/1 [00:00<00:00,  3.56it/s]
    2022-08-20 02:34:25,663 WARNING:Using custom data configuration default-e89076d74da83269
    2022-08-20 02:34:25,669 WARNING:Reusing dataset text (/cluster/home/username/.cache/huggingface/datasets/text/default-e89076d74da83269/0.0.0/21a506d1b2b34316b1e82d0bd79066905d846e5d7e619823c0dd338d6f1fa6ad)
    
      0%|          | 0/1 [00:00<?, ?it/s]
    100%|██████████| 1/1 [00:00<00:00,  9.74it/s]
    100%|██████████| 1/1 [00:00<00:00,  9.71it/s]
    DatasetDict({
        train: Dataset({
            features: ['text'],
            num_rows: 787650
        })
        validation: Dataset({
            features: ['text'],
            num_rows: 262548
        })
    })
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 123732 closing signal SIGTERM
    ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 123731) of binary: /cluster/apps/nss/gcc-8.2.0/python/3.10.4/x86_64/bin/python
    /cluster/shadow/.lsbatch/1660955521.229195199: line 8: 123724 Segmentation fault      (core dumped) CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node 2 train.py

Based on:

RuntimeError: CUDA error: all CUDA-capable devices are busy or unavailable

it seems your workstation has trouble finding or communicating with the GPUs.
Are you able to run any workload on these GPUs without DDP?

I’m running it on a cluster (the cluster uses the IBM LSF batch system). Without DDP the GPU is found and it works. The requested two GPUs are on a single node of the cluster.

Just to make sure: you are able to run PyTorch scripts for a single GPU using any available GPU, right? Or did you only test the default one (GPU0)?
In any case, could you check if your launch script is masking the GPUs by accident?
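One way to check this is to launch a tiny script with the same command and have every rank print the variables that decide which GPUs it sees. This is a minimal sketch. The file name check_env.py and the helper launch_env_report are hypothetical. RANK, LOCAL_RANK and WORLD_SIZE are set by the launcher. CUDA_VISIBLE_DEVICES comes from the command line and may also be set by the batch system; that is worth checking on LSF, but this thread does not confirm it.

```python
import os
from typing import Mapping, Optional

# Variables that determine which GPUs a rank can see and which one it should use.
KEYS = ("CUDA_VISIBLE_DEVICES", "RANK", "LOCAL_RANK", "WORLD_SIZE")


def launch_env_report(env: Optional[Mapping[str, str]] = None) -> str:
    """Format the launcher/GPU-related environment of the current process."""
    env = os.environ if env is None else env
    return " ".join(f"{key}={env.get(key, '<unset>')}" for key in KEYS)


if __name__ == "__main__":
    # Prefix with the PID so lines printed by different ranks can be told apart.
    print(f"pid={os.getpid()} {launch_env_report()}")
```

Run it as `CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node 2 check_env.py`. The script ignores the `--local_rank` argument that the launcher may append. Each of the two ranks should report a different LOCAL_RANK and the same CUDA_VISIBLE_DEVICES.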

What do you mean by any available GPU? If I’m using a single GPU, only one GPU is available.

In my launch script I’m not specifying the GPU, and I think Huggingface Trainer should not do it either.

I can send you my launch script if you want (I don’t want to make it public).

Edit: I have now run the following script on the cluster (without torch.distributed.launch):

import torch
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_name(1))
print(torch.cuda.get_device_properties(0))
print(torch.cuda.get_device_properties(1))

The output is as follows:

True
2
Quadro RTX 6000
Quadro RTX 6000
_CudaDeviceProperties(name='Quadro RTX 6000', major=7, minor=5, total_memory=24220MB, multi_processor_count=72)
_CudaDeviceProperties(name='Quadro RTX 6000', major=7, minor=5, total_memory=24220MB, multi_processor_count=72)

Good news! I found the problem. In my script I’m using Huggingface and I’m putting my model on the GPU using the following code:

from transformers import GPTJForCausalLM
import torch
model = GPTJForCausalLM.from_pretrained(
                "EleutherAI/gpt-j-6B",
                revision="float16",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                use_cache=False,
                gradient_checkpointing=True
)
model.to("cuda")

This causes the error because I think it tries to put both instances on the same GPU. When I remove the model.to("cuda") line it works fine (but it is then not running on GPU I guess).

Do you know how I can put it on GPU?

Calling to('cuda') will move the object to the current default GPU, which is cuda:0 in every process unless torch.cuda.set_device() has been called. However, the error still points to a potential setup issue, as PyTorch is not able to find any GPUs:

root@3527793d5857:/workspace# python -c "import torch; print(torch.randn(1).to('cuda'))"
tensor([-0.1380], device='cuda:0')
root@3527793d5857:/workspace# CUDA_VISIBLE_DEVICES="" python -c "import torch; print(torch.randn(1).to('cuda'))"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/conda/lib/python3.8/site-packages/torch/cuda/__init__.py", line 222, in _lazy_init
    torch._C._cuda_init()
RuntimeError: No CUDA GPUs are available

Another thing you can try is to set the CUDA device for each rank before training starts, using torch.cuda.set_device. This is required before using an NCCL process group; see the PyTorch 1.12 documentation page "Distributed communication package - torch.distributed":

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
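Here is a minimal sketch of that per-rank setup. It assumes the launcher exports LOCAL_RANK; torchrun does, and the deprecation warning in the log asks scripts to read it from os.environ. The helper names local_rank_from_env and init_rank_on_own_gpu are hypothetical. torch is imported inside the function only so that the env-parsing helper can be used and tested on a machine without PyTorch.

```python
import os
from typing import Mapping, Optional


def local_rank_from_env(env: Optional[Mapping[str, str]] = None) -> int:
    """Read the per-node rank exported by the launcher (assumed: LOCAL_RANK)."""
    env = os.environ if env is None else env
    if "LOCAL_RANK" not in env:
        raise RuntimeError(
            "LOCAL_RANK is not set; start the script with torchrun "
            "or python -m torch.distributed.launch"
        )
    return int(env["LOCAL_RANK"])


def init_rank_on_own_gpu():
    """Pin this process to its own GPU, then create the NCCL process group."""
    # Imported lazily so local_rank_from_env() works without torch installed.
    import torch
    import torch.distributed as dist

    local_rank = local_rank_from_env()
    # Do this before any CUDA work: afterwards "cuda" and current_device()
    # refer to cuda:<local_rank> in this process instead of cuda:0.
    torch.cuda.set_device(local_rank)
    # Rendezvous settings (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) come from
    # the environment variables the launcher sets.
    dist.init_process_group(backend="nccl")
    return torch.device("cuda", local_rank)
```

In train.py this would be called once, e.g. `device = init_rank_on_own_gpu()`, before the model is built, followed by `model.to(device)`. Because set_device changes the current device, an existing `model.to("cuda")` would then also resolve to a different GPU in each rank.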

Thank you for your help.

How can I set the cuda device for each rank? With two GPUs that would be

torch.cuda.set_device(0) and torch.cuda.set_device(1) but how can I set this for the ranks?

Edit: I got it working using the following:

from transformers import GPTJForCausalLM
import torch
import torch.distributed as dist
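# Create the NCCL process group; rendezvous info (MASTER_ADDR, RANK, ...) comes from the launcher
dist.init_process_group("nccl")
model = GPTJForCausalLM.from_pretrained(
                "EleutherAI/gpt-j-6B",
                revision="float16",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                use_cache=False,
                gradient_checkpointing=True
)
# Give each rank its own GPU. dist.get_rank() is the global rank, which equals the
# local GPU index only when all processes run on a single node (as in this job).
model.to(device=dist.get_rank())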
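dist.get_rank() returns the global rank. It only maps ranks to GPUs correctly while all processes share one node, which is the case in this thread. Below is a minimal sketch of a choice that also holds for multi-node runs. It assumes the launcher sets LOCAL_RANK. As a fallback, it assumes every node has the same number of GPUs and that ranks are filled node by node. pick_gpu_index is a hypothetical helper.

```python
import os
from typing import Mapping, Optional


def pick_gpu_index(global_rank: int, gpus_per_node: int,
                   env: Optional[Mapping[str, str]] = None) -> int:
    """Return the GPU index this process should use on its own node."""
    env = os.environ if env is None else env
    # Preferred: the per-node rank exported by torchrun / torch.distributed.launch.
    if "LOCAL_RANK" in env:
        return int(env["LOCAL_RANK"])
    # Fallback (assumption): same GPU count on every node, ranks filled node by node.
    if gpus_per_node <= 0:
        raise ValueError("gpus_per_node must be positive")
    return global_rank % gpus_per_node


# On a single node with 2 GPUs the result matches dist.get_rank():
print([pick_gpu_index(r, 2, env={}) for r in range(2)])  # [0, 1]
# On a second node, global ranks 2 and 3 map back to GPUs 0 and 1:
print([pick_gpu_index(r, 2, env={}) for r in range(2, 4)])  # [0, 1]
```

In train.py this could look like `model.to(pick_gpu_index(dist.get_rank(), torch.cuda.device_count()))`. Ideally it is combined with torch.cuda.set_device, as suggested earlier in the thread, so that NCCL and any bare "cuda" calls use the same GPU.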