Torchrun failing with MIGs? CUDA call failed lazily at initialization with error: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch

Hello,

I’m struggling to find a way to run a training job on a single node with multiple GPUs. The host is a DGX-A100, and the A100s have been split into MIG slices. I allocated 2 MIG slices for my experiment.

I took the code from the PyTorch examples: https://github.com/pytorch/examples/tree/main/distributed/ddp-tutorial-series

But it fails to run:

fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ torchrun --standalone --nnodes=1 --nproc-per-node=2 multigpu_torchrun.py  50 10
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.                                                                                                                          
WARNING:torch.distributed.run:                                                                                                                                                                                     
*****************************************                                                                                                                                                                          
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.   
*****************************************                                                                                                                                                                          
Traceback (most recent call last):                                                                                                                                                                                 
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 260, in _lazy_init                                                                                                     
    queued_call()                                                                                                                                                                                                  
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 145, in _check_capability                                                                                              
    capability = get_device_capability(d)                                                                                                                                                                          
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 381, in get_device_capability
    prop = get_device_properties(device)                                                                                                                                                                           
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 399, in get_device_properties
    return _get_device_properties(device)  # type: ignore[name-defined]
RuntimeError: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch. 

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "multigpu_torchrun.py", line 111, in <module>
    main(args.save_every, args.total_epochs, args.batch_size)
  File "multigpu_torchrun.py", line 95, in main
    ddp_setup()
  File "multigpu_torchrun.py", line 15, in ddp_setup
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 350, in set_device                  
    torch._C._cuda_setDevice(device)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 264, in _lazy_init
    raise DeferredCudaCallError(msg) from e
torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug 
to PyTorch.                  
                                                                                                                                                                                                                   
CUDA call was originally invoked at:
[...]                                                           
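
For context, here is roughly what ddp_setup() in the tutorial does, reconstructed from the traceback above (the exact code is in the linked repo). My understanding, which may well be wrong, is that PyTorch counts two devices from CUDA_VISIBLE_DEVICES=0,1, while the CUDA runtime only enumerates a single MIG instance per process, so the lazy initialization triggered by set_device() trips the device < num_gpus assert:

import os
import torch
from torch.distributed import init_process_group

def ddp_setup():
    # torchrun sets LOCAL_RANK to 0..nproc_per_node-1 for each worker it spawns;
    # the tutorial maps that rank directly to a CUDA device index
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # the process group is initialized with the nccl backend
    # (the exact order of these two calls may differ in the repo)
    init_process_group(backend="nccl")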

Reading here and there, I saw it might be linked to the backend, and I tried both nccl and gloo with no success.

Actually, even trying to collect information about the environment fails:

fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ python3 -m torch.utils.collect_env
Collecting environment information...
Traceback (most recent call last):
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 260, in _lazy_init
    queued_call()
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 145, in _check_capability
    capability = get_device_capability(d)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 381, in get_device_capability
    prop = get_device_properties(device)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 399, in get_device_properties
    return _get_device_properties(device)  # type: ignore[name-defined]
RuntimeError: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch. 

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/utils/collect_env.py", line 602, in <module>
    main()
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/utils/collect_env.py", line 585, in main
    output = get_pretty_env_info()
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/utils/collect_env.py", line 580, in get_pretty_env_info
    return pretty_str(get_env_info())
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/utils/collect_env.py", line 451, in get_env_info
    cuda_module_loading=get_cuda_module_loading_config(),
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/utils/collect_env.py", line 406, in get_cuda_module_loading_config
    torch.cuda.init()
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 216, in init
    _lazy_init()
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 264, in _lazy_init
    raise DeferredCudaCallError(msg) from e
torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch. 

CUDA call was originally invoked at:

['  File "/usr/lib/python3.8/runpy.py", line 185, in _run_module_as_main\n    mod_name, mod_spec, code = _get_module_details(mod_name, _Error)\n', '  File "/usr/lib/python3.8/runpy.py", line 111, in _get_module_details\n    __import__(pkg_name)\n', '  File "<frozen importlib._bootstrap>", line 991, in _find_and_load\n', '  File "<frozen importlib._bootstrap>", line 961, in _find_and_load_unlocked\n', '  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed\n', '  File "<frozen importlib._bootstrap>", line 991, in _find_and_load\n', '  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked\n', '  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked\n', '  File "<frozen importlib._bootstrap_external>", line 848, in exec_module\n', '  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed\n', '  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/__init__.py", line 1146, in <module>\n    _C._initExtension(manager_path())\n', '  File "<frozen importlib._bootstrap>", line 991, in _find_and_load\n', '  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked\n', '  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked\n', '  File "<frozen importlib._bootstrap_external>", line 848, in exec_module\n', '  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed\n', '  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 197, in <module>\n    _lazy_call(_check_capability)\n', '  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/cuda/__init__.py", line 195, in _lazy_call\n    _queued_calls.append((callable, traceback.format_stack()))\n']

For the environment, here is some information:

fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ python3 -c "import torch; print(torch.cuda.is_available())"
True
fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ python3 -c "import torch; print(torch.cuda.device_count())"
2

For CUDA_VISIBLE_DEVICES:

fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ echo $CUDA_VISIBLE_DEVICES 
0,1

and the nvidia-smi output:

fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ nvidia-smi
Thu Jul 20 16:43:59 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  Off  | 00000000:01:00.0 Off |                   On |
| N/A   49C    P0    57W / 275W |     45MiB / 81920MiB |     N/A      Default |
|                               |                      |              Enabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  Off  | 00000000:47:00.0 Off |                   On |
| N/A   49C    P0    66W / 275W |     45MiB / 81920MiB |     N/A      Default |
|                               |                      |              Enabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  Off  | 00000000:81:00.0 Off |                   On |
| N/A   49C    P0    58W / 275W |     45MiB / 81920MiB |     N/A      Default |
|                               |                      |              Enabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA DGX Display  Off  | 00000000:C1:00.0 Off |                  N/A |
| 34%   45C    P8    N/A /  50W |      1MiB /  4096MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM...  Off  | 00000000:C2:00.0 Off |                   On |
| N/A   48C    P0    56W / 275W |     48MiB / 81920MiB |     N/A      Default |
|                               |                      |              Enabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| MIG devices:                                                                |
+------------------+----------------------+-----------+-----------------------+
| GPU  GI  CI  MIG |         Memory-Usage |        Vol|         Shared        |
|      ID  ID  Dev |           BAR1-Usage | SM     Unc| CE  ENC  DEC  OFA  JPG|
|                  |                      |        ECC|                       |
|==================+======================+===========+=======================|
|  0    7   0   0  |      6MiB /  9728MiB | 14      0 |  1   0    0    0    0 |
|                  |      0MiB / 16383MiB |           |                       |
+------------------+----------------------+-----------+-----------------------+
|  0    8   0   1  |      6MiB /  9728MiB | 14      0 |  1   0    0    0    0 |
|                  |      0MiB / 16383MiB |           |                       |
+------------------+----------------------+-----------+-----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

and the libraries:

fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ python3 -m pip list | grep torch
pytorch-lightning        1.9.5               
torch                    2.0.1               
torchmetrics             1.0.1               
torchvision              0.15.2
fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ python3 --version
Python 3.8.10

Do you think this might be related to the way we configured MIG, or does something come to mind that I’m not doing the right way?

Thanks for your help

edit: I forgot to mention that the code works with 1 process and 1 MIG:

fix_jer@dgxa100:~$ srun --partition=interactive10 --gres=gpu:1g.10gb:1 --ntasks=1 --cpus-per-task=4 --pty bash
fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ torchrun --standalone --nnodes=1 --nproc-per-node=1 multigpu_torchrun.py  50 10
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
[GPU0] Epoch 0 | Batchsize: 32 | Steps: 64
Epoch 0 | Training snapshot saved at snapshot.pt
[GPU0] Epoch 1 | Batchsize: 32 | Steps: 64
[GPU0] Epoch 2 | Batchsize: 32 | Steps: 64
[GPU0] Epoch 3 | Batchsize: 32 | Steps: 64

This won’t work as only a single MIG slice can be used by a process.

Thank you for your quick answer.

I had indeed read what you say: one process per MIG.

But I thought that multiple processes would be spawned with DDP or torchrun, each one using its own MIG; the same way we would run a training distributed over multiple nodes, with one master process.

What would be the proper way to do that? (I do not know if that helps, but the DGX is actually managed with slurm.)

You would either need to use multiple devices in DDP or a single process on a single MIG slice. Using multiple MIG slices in any setup is not supported.

Sorry if there is a point I’m misunderstanding.

I had in mind the (strange, I agree :slight_smile:) use of multi-node training on the same host, relying on the idea that I run multiple slurm jobs, each with its own single MIG slice, and each of these processes would only use its own MIG.

More precisely, I tried the following, which has failed so far:

  • one allocation on the dgxa100 with:
:~$ torchrun --nproc-per-node=1 --nnodes=2 --node_rank=0 --rdzv-id=456 --rdzv-backend=c10d --rdzv-endpoint=138.195.195.185:29603 multinode.py 50 10
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
Traceback (most recent call last):
  File "multinode.py", line 112, in <module>
    main(args.save_every, args.total_epochs, args.batch_size)
  File "multinode.py", line 99, in main
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
  File "multinode.py", line 38, in __init__
    self.model = DDP(self.model, device_ids=[self.local_rank])
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1275, internal error, NCCL version 2.14.3
ncclInternalError: Internal check failed.
Last error:
Duplicate GPU detected : rank 0 and rank 1 both on CUDA device 1000
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1160168) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/raid/home/fix_jer/.local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
multinode.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-07-21_17:11:08
  host      : xxxxx
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 1160168)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
  • a second srun on the same dgxa100 host with:
torchrun --nproc-per-node=1 --nnodes=2 --node_rank=1 --rdzv-id=456 --rdzv-backend=c10d --rdzv-endpoint=138.195.195.185:29603 multinode.py 50 10
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
[W socket.cpp:426] [c10d] The server socket has failed to bind to [::]:29603 (errno: 98 - Address already in use).
[W socket.cpp:426] [c10d] The server socket has failed to bind to ?UNKNOWN? (errno: 98 - Address already in use).
[E socket.cpp:462] [c10d] The server socket has failed to listen on any local network address.
Traceback (most recent call last):
  File "multinode.py", line 112, in <module>
    main(args.save_every, args.total_epochs, args.batch_size)
  File "multinode.py", line 99, in main
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
  File "multinode.py", line 38, in __init__
    self.model = DDP(self.model, device_ids=[self.local_rank])
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1275, internal error, NCCL version 2.14.3
ncclInternalError: Internal check failed.
Last error:
Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 1000
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1160169) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/raid/home/fix_jer/.local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
multinode.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-07-21_17:11:08
  host      : xxxxx
  rank      : 1 (local_rank: 0)
  exitcode  : 1 (pid: 1160169)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Do you think that should work? Do you know what I’m doing wrong in this use of two separate processes on the same host?

This would be a valid setup, but you are trying to use multiple MIG slices via torchrun, which is not supported. You can use a single MIG slice in a single process without multi-GPU support.

Yes, I agree with the objective: one MIG slice per process. That’s the reason I changed the way I tried to run the code in my last comment.

In my last run: two torchrun calls, each in its own slurm allocation. In each slurm allocation, I requested only one MIG slice, and each torchrun runs a single process (--nproc-per-node=1).

With respect to my last comment, there was another issue: both torchrun processes were trying to start the backend store. With this fixed, let me try to be clearer about my current, not yet working, setup.

How I allocate two MIG slices, independently

I do the two allocations with slurm (although some elements are specific to our slurm setup):

~:$ srun --partition=interactive10 --gres=gpu:1g.10gb:1 --ntasks=1 --cpus-per-task=4 --pty bash

This gives me access to a single MIG slice:

fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ echo $CUDA_VISIBLE_DEVICES 
0
fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ python3 -c "import torch; print(torch.cuda.device_count())"
1

And this is true for both of my parallel sessions.
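
Before launching anything, a quick sanity check inside each allocation (a minimal snippet of my own, not from the tutorial) confirms what the slice looks like from PyTorch:

import torch

# each slurm allocation should expose exactly one device: the MIG slice
print(torch.cuda.device_count())
# the reported name/properties hint at which physical GPU backs the slice
# (depending on driver/PyTorch version, the name may or may not mention the MIG profile)
props = torch.cuda.get_device_properties(0)
print(props.name, props.total_memory // 2**20, "MiB")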

Then I run, in the first slurm session:

torchrun --nproc-per-node=1 --nnodes=2 --node_rank=0 --rdzv-id=456 --rdzv-backend=c10d --rdzv_conf=is_host=1 --rdzv-endpoint=138.195.195.185:29603 multinode.py 50 10

and in the second slurm session:

torchrun --nproc-per-node=1 --nnodes=2 --node_rank=1 --rdzv-id=456 --rdzv-backend=c10d --rdzv-conf=is_host=0 --rdzv-endpoint=138.195.195.185:29603 multinode.py 50 10

The only difference between the two calls is --node_rank (either 0 or 1) and --rdzv-conf=is_host= (either 1 or 0).

However, they both fail with:

[...]
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1275, internal error, NCCL version 2.14.3
Last error:
Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 1000
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 2051601) of binary: /usr/bin/python3
[...]
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1275, internal error, NCCL version 2.14.3
Last error:
Duplicate GPU detected : rank 0 and rank 1 both on CUDA device 1000
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 2051602) of binary: /usr/bin/python3

It seems to me that, in addition to one MIG per process, there is another constraint: even if MIG slices are used by independent processes, they should not be on the same physical device. (The "CUDA device 1000" in the NCCL error presumably refers to the PCI bus ID of GPU 0, 01:00.0, which hosts both of my slices.) I do not know if this makes sense.

I gave it one last try, with the first process using a MIG slice on one GPU and the other process using a MIG slice from another GPU.
This time, nvidia-smi returns, respectively:

slurm job1: ~$ nvidia-smi
[....]
+-----------------------------------------------------------------------------+
| MIG devices:                                                                |
+------------------+----------------------+-----------+-----------------------+
| GPU  GI  CI  MIG |         Memory-Usage |        Vol|         Shared        |
|      ID  ID  Dev |           BAR1-Usage | SM     Unc| CE  ENC  DEC  OFA  JPG|
|                  |                      |        ECC|                       |
|==================+======================+===========+=======================|
|  0    7   0   0  |      6MiB /  9728MiB | 14      0 |  1   0    0    0    0 |
|                  |      0MiB / 16383MiB |           |                       |
+------------------+----------------------+-----------+-----------------------+

slurm job2: ~$ nvidia-smi
[...]
+-----------------------------------------------------------------------------+
| MIG devices:                                                                |
+------------------+----------------------+-----------+-----------------------+
| GPU  GI  CI  MIG |         Memory-Usage |        Vol|         Shared        |
|      ID  ID  Dev |           BAR1-Usage | SM     Unc| CE  ENC  DEC  OFA  JPG|
|                  |                      |        ECC|                       |
|==================+======================+===========+=======================|
|  4    1   0   0  |     22MiB / 40192MiB | 42      0 |  3   0    2    0    0 |
|                  |      0MiB / 65535MiB |           |                       |
+------------------+----------------------+-----------+-----------------------+

But now, the torchrun calls fail with a different error, Cuda failure 'invalid argument'.

For the first process:

fix_jer@dgxa100:~/GIT/examples/distributed/ddp-tutorial-series$ torchrun --nproc-per-node=1 --nnodes=2 --node_rank=0 --rdzv-backend=c10d --rdzv-conf=is_host=1 --rdzv-endpoint=138.195.195.185:29603 multinode.py 50 10
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
Traceback (most recent call last):                                                                       
  File "multinode.py", line 112, in <module>                                                             
    main(args.save_every, args.total_epochs, args.batch_size)                  
  File "multinode.py", line 99, in main                                                                  
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path) 
  File "multinode.py", line 38, in __init__                                                              
    self.model = DDP(self.model, device_ids=[self.local_rank])                 
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)       
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1275, internal error, NCCL version 2.14.3
ncclInternalError: Internal check failed.                                                                
Last error:                                                                                              
Cuda failure 'invalid argument'                                                                          
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 2072136) of binary: /usr/bin/python3
Traceback (most recent call last):                                                                       
  File "/raid/home/fix_jer/.local/bin/torchrun", line 8, in <module>           
    sys.exit(main())                                                                                     
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)                                                                            
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run 
    elastic_launch(                                                                                      
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))            
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(                                                                              
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:             
============================================================                   
multinode.py FAILED                                                                                      
------------------------------------------------------------                   
Failures:                                                                                                
  <NO_OTHER_FAILURES>                                                                                    
------------------------------------------------------------                   
Root Cause (first observed failure):                                                                     
[0]:                                                                                                     
  time      : 2023-07-23_06:32:21                                                                        
  host      : hubia-dgx.centralesupelec.fr                                                               
  rank      : 1 (local_rank: 0)                                                                          
  exitcode  : 1 (pid: 2072136)     
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
           

and for the second:

torchrun --nproc-per-node=1 --nnodes=2 --node_rank=1 --rdzv-backend=c10d --rdzv-conf=is_host=0 --rdzv-endpoint=138.195.195.185:29603  multinode.py 50 10
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
Traceback (most recent call last):
  File "multinode.py", line 112, in <module>
    main(args.save_every, args.total_epochs, args.batch_size)
  File "multinode.py", line 99, in main
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
  File "multinode.py", line 38, in __init__
    self.model = DDP(self.model, device_ids=[self.local_rank])
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1275, internal error, NCCL version 2.14.3
ncclInternalError: Internal check failed.
Last error:
Cuda failure 'invalid argument'
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 2072137) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/raid/home/fix_jer/.local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/raid/home/fix_jer/.local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
multinode.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-07-23_06:32:21
  host      : hubia-dgx.centralesupelec.fr
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 2072137)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

You are still trying to create NCCL communication between two ranks, which then try to use the same device and raise the error. I still don’t completely understand why you are using multiple processes or even NCCL, since a MIG slice can only be accessed by a single process without any communication. So could you explain why NCCL is involved at all?

Maybe I’m doing it the wrong way, with the wrong backend for distribution. All this is pretty new to me.

What I’m trying to do is the following.

I would like to run a distributed training across multiple MIG slices. Given the constraint of a single MIG per process, I have the feeling the way to do it is to proceed as if I were distributing the training across multiple hosts, i.e. with a server and clients.

Therefore, I run multiple slurm allocations on the same host, each with a single MIG slice, and then I’m wondering how to run the processes in each slurm allocation.

From the docs, it seems to me that NCCL can be used for multi-node distributed training, so I thought it would also allow communication between multiple processes on the same host, given that I’m explicitly specifying which one hosts the store.

But then, I’m probably wrong about the arguments provided to torchrun. The rationale I had was (see the small env-var check after the list):

  • --nproc-per-node=1 to enforce the constraint of 1 process for 1 MIG
  • --nnodes=2 because the training is distributed over “virtually” 2 nodes, although hosted on the same host
  • --node_rank=0 or 1: I thought this had to be specified for training over multiple nodes (oh! I notice I wrote node_rank instead of node-rank, not sure if this is critical)
  • --rdzv-conf=is_host=0 or 1 to indicate which process will host the store
  • --rdzv-endpoint: set to the IP of my DGX host, with a randomly chosen port
  • --rdzv-backend=c10d: not an informed choice on my side, I just took the value suggested in the DDP tutorial
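
To sanity-check that wiring, a minimal snippet (my own, not from the tutorial) can be put at the top of multinode.py to print the environment variables torchrun is documented to set for each worker:

import os

# torchrun / torch.distributed.elastic exports these for every worker it launches
for var in ("RANK", "LOCAL_RANK", "WORLD_SIZE", "LOCAL_WORLD_SIZE",
            "GROUP_RANK", "MASTER_ADDR", "MASTER_PORT"):
    print(var, "=", os.environ.get(var))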

Hope this is clearer. Would you suggest another combination of options to get this multi-node-like distributed training running on the same host?

Thank you for your help anyway!

There is no GPU-to-GPU P2P (both PCIe and NVLINK) support in MIG mode. This means no support for multi-GPU or multi-node training.