RuntimeError: CUDA error: initialization error when calling torch.distributed.init_process_group using torch multiprocessing

I created a pytest decorator that starts multiple processes (using torch.multiprocessing) to run model-parallel distributed unit tests with PyTorch distributed. While I was fixing some unit test logic, I suddenly started hitting the CUDA initialization error below, seemingly at random. Since then, all my unit tests have been failing, and I traced the failure back to the call to torch.distributed.init_process_group(…) inside the decorator.

Error traceback:

$ python3 -m pytest test/test_distributed.py::test_dummy
Process Process-1:
Traceback (most recent call last):
  File "/usr/lib64/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/usr/lib64/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/fsx-dev/FSxLustre20201016T182138Z/prraman/home/workspace/ws_M5_meg/src/M5ModelParallelism/test_script/commons_debug.py", line 34, in dist_init
    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size, init_method=init_method)
  File "/usr/local/lib64/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 480, in init_process_group
    barrier()
  File "/usr/local/lib64/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 2186, in barrier
    work = _default_pg.barrier()
RuntimeError: CUDA error: initialization error

Below is the decorator I created:

# file: test_distributed.py
import os
import time
import torch
import torch.distributed as dist
from torch.multiprocessing import Process, set_start_method
import pytest

# Worker timeout *after* the first worker has completed.
WORKER_TIMEOUT = 120


def distributed_test_debug(world_size=2, backend='nccl'):
    """A decorator for executing a function (e.g., a unit test) in a distributed manner.
        This decorator manages the spawning and joining of processes, initialization of
        torch.distributed, and catching of errors.

        Usage example:
        @distributed_test_debug(world_size=[2, 3])
        def my_test():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            assert(rank < world_size)

    Arguments:
        world_size (int or list): number of ranks to spawn. Can be a list to spawn
        multiple tests.
    """
    def dist_wrap(run_func):
        """Second-level decorator for dist_test. This actually wraps the function. """
        def dist_init(local_rank,
                      num_procs,
                      *func_args, **func_kwargs):
            """Initialize torch.distributed and execute the user function. """
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = '29503'
            os.environ['LOCAL_RANK'] = str(local_rank)
            # NOTE: unit tests don't support multi-node so local_rank == global rank
            os.environ['RANK'] = str(local_rank)
            os.environ['WORLD_SIZE'] = str(num_procs)
            master_addr = os.environ['MASTER_ADDR']
            master_port = os.environ['MASTER_PORT']
            rank = local_rank

            # Pin each rank to its own GPU before creating the NCCL process group,
            # as the NCCL backend expects one device per rank.
            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)

            # Initialize the default distributed process group (and the distributed package).
            # Use num_procs here, not the decorator's world_size, which may be a list.
            init_method = "tcp://" + master_addr + ":" + master_port
            print('inside dist_init, world_size: ', num_procs)
            torch.distributed.init_process_group(backend, rank=rank, world_size=num_procs, init_method=init_method)
            print("rank={} init complete".format(rank))

            #torch.distributed.destroy_process_group()
            # print("rank={} destroy complete".format(rank))

            if torch.distributed.get_rank() == 0:
                print('> testing initialize_model_parallel with size {} ...'.format(2))

            run_func(*func_args, **func_kwargs)

        def dist_launcher(num_procs,
                          *func_args, **func_kwargs):
            """Launch processes and gracefully handle failures. """

            # Spawn all workers on subprocesses.
            #set_start_method('spawn')
            processes = []
            for local_rank in range(num_procs):
                p = Process(target=dist_init,
                            args=(local_rank,
                                  num_procs,
                                  *func_args),
                            kwargs=func_kwargs)
                p.start()
                processes.append(p)

            # Now loop and wait for a test to complete. The spin-wait here isn't a big
            # deal because the number of processes will be O(#GPUs) << O(#CPUs).
            any_done = False
            while not any_done:
                for p in processes:
                    if not p.is_alive():
                        any_done = True
                        break

            # Wait for all other processes to complete
            for p in processes:
                p.join(WORKER_TIMEOUT)

            failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
            for rank, p in failed:
                # If it still hasn't terminated, kill it because it hung.
                if p.exitcode is None:
                    p.terminate()
                    pytest.fail(f'Worker {rank} hung.', pytrace=False)
                if p.exitcode < 0:
                    pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}',
                                pytrace=False)
                if p.exitcode > 0:
                    pytest.fail(f'Worker {rank} exited with code {p.exitcode}',
                                pytrace=False)

        def run_func_decorator(*func_args, **func_kwargs):
            """Entry point for @distributed_test(). """

            if isinstance(world_size, int):
                dist_launcher(world_size, *func_args, **func_kwargs)
            elif isinstance(world_size, list):
                for procs in world_size:
                    dist_launcher(procs, *func_args, **func_kwargs)
                    time.sleep(0.5)
            else:
                raise TypeError('world_size must be an integer or a list of integers.')

        return run_func_decorator

    return dist_wrap

Below is how I apply the decorator to a test:

@distributed_test_debug(world_size=2)
def test_dummy():
    assert 1 == 1

I have seen issues reported in the past about torch.multiprocessing and CUDA not working well together, but I'm not sure whether this is related. Is there a different way I should create the processes to avoid this problem? Any help is appreciated.

I am using PyTorch version 1.8.0a0+ae5c2fe.

You’re right that getting CUDA and multiprocessing to work together correctly has a number of pitfalls. I’d suggest starting with the "Multiprocessing best practices" page in the PyTorch 1.10.0 documentation for more background.

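The main recommendation is to use the spawn start method instead of the default fork start method (on Linux), via something like multiprocessing.set_start_method(...). The CUDA runtime does not support fork: if CUDA has already been initialized in the parent process (for example by an earlier test, or by an import or call that touches the GPU), a forked child cannot initialize it again, which is a common cause of the "CUDA error: initialization error" you are seeing. With spawn, each worker starts in a fresh interpreter and initializes CUDA on its own.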
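One caveat: simply uncommenting set_start_method('spawn') in the decorator is probably not enough. With spawn, each Process object (its target and args) is pickled and sent to a fresh interpreter, so the target must be importable by name, and the nested dist_init function is not. Also, set_start_method() changes global state and raises RuntimeError if the start method was already set (unless force=True), so calling it on every test launch would fail. The sketch below uses only the standard library; is_spawn_safe, make_nested_target, worker and launch are hypothetical names for illustration. Since torch.multiprocessing wraps the standard multiprocessing module, I'd expect torch.multiprocessing.get_context("spawn") to behave the same way, but treat that as an assumption.

```python
import multiprocessing
import pickle


def is_spawn_safe(obj):
    """Hypothetical helper: return True if `obj` can be pickled, which the
    "spawn" start method requires for a Process target and its args."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, TypeError, pickle.PicklingError):
        # Nested functions and lambdas end up here because pickle stores
        # functions by module + qualified name; objects such as locks raise TypeError.
        return False


def make_nested_target():
    # Mirrors the question's decorator, where dist_init is defined inside dist_wrap.
    def dist_init(local_rank, num_procs):
        return local_rank, num_procs
    return dist_init


def worker(local_rank, num_procs):
    # A module-level function can be re-imported by name in a spawned child.
    # In the real test, init_process_group and the test body would run here.
    return local_rank, num_procs


# A context object selects "spawn" for this launcher only, without touching
# the global start method the way multiprocessing.set_start_method() does.
spawn_ctx = multiprocessing.get_context("spawn")


def launch(num_procs):
    procs = [spawn_ctx.Process(target=worker, args=(rank, num_procs))
             for rank in range(num_procs)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    return [p.exitcode for p in procs]


if __name__ == "__main__":
    print(launch(2))
```

In the decorator, this would mean moving dist_init to module level and starting workers with spawn_ctx.Process instead of Process; the decorated test function itself also has to be reachable by name from the child, which a wrapper that replaces it in the module does not guarantee.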

Alternatively, you can use torch.multiprocessing.spawn (mp.spawn), which starts the workers with the spawn start method by default; see the "Multiprocessing package - torch.multiprocessing" page in the PyTorch 1.10.0 documentation (check torch.multiprocessing.spawn).
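With torch.multiprocessing.spawn, the launcher reduces to roughly the following. This is a minimal sketch, not a drop-in replacement for the decorator: it assumes the gloo backend on CPU by default so it can run without GPUs, and worker, launch and find_free_port are hypothetical names. mp.spawn calls fn(i, *args) with the process index as the first argument; with join=True it waits for all workers, and if one exits with a non-zero status the others are killed and an exception is raised in the parent. Asking the OS for a free port avoids collisions with a hard-coded MASTER_PORT between consecutive launches, at the cost of a small race between closing the probe socket and rank 0 binding the port.

```python
import socket

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def find_free_port():
    """Hypothetical helper: ask the OS for an unused TCP port on localhost."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def worker(rank, world_size, port, backend):
    # mp.spawn calls this as worker(rank, *args); it must live at module level
    # so the spawned children can import it.
    if backend == "nccl":
        # Pin this rank to its own GPU before creating the NCCL process group.
        torch.cuda.set_device(rank)
    dist.init_process_group(
        backend,
        init_method=f"tcp://127.0.0.1:{port}",
        rank=rank,
        world_size=world_size,
    )
    try:
        device = "cuda" if backend == "nccl" else "cpu"
        t = torch.ones(1, device=device)
        dist.all_reduce(t)  # default op is SUM, so every rank ends with world_size
        assert t.item() == world_size
    finally:
        dist.destroy_process_group()


def launch(world_size=2, backend="gloo"):
    # join=True blocks until all workers exit and raises if any of them failed.
    mp.spawn(worker, args=(world_size, find_free_port(), backend),
             nprocs=world_size, join=True)


if __name__ == "__main__":
    launch(world_size=2, backend="gloo")
```

From a pytest test you would call launch(...) inside the test function. With backend="nccl", each rank pins itself to GPU `rank`, so world_size should not exceed the number of visible GPUs.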