Simple NCCL scatter code snippet causes SIGSEGV

Hi, forum.
I want to distribute some tensors to other GPUs with NCCL scatter, run a computation on each of them, and then gather the results back onto a single GPU with NCCL gather. However, I get a segmentation fault when the code calls SimpleQueue.get().

Could you help me understand or fix this issue?

  File ".../python3.11/site-packages/torch/storage.py", line 955 in _new_shared_cuda
  File ".../python3.11/site-packages/torch/multiprocessing/reductions.py", line 121 in rebuild_cuda_tensor
  File ".../python3.11/multiprocessing/queues.py", line 367 in get

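For reference, as far as I understand, a CUDA tensor put into a torch.multiprocessing queue is not copied: only a CUDA IPC handle plus some metadata crosses the process boundary, and get() on the other side rebuilds the tensor by opening that handle (that is the rebuild_cuda_tensor / _new_shared_cuda in the traceback above). Here is a minimal standalone sketch of just that mechanism, assuming a single visible CUDA device (not part of my repro, just an illustration):

import torch
import torch.multiprocessing as mp


def consumer(q):
    # get() unpickles the IPC handle and calls rebuild_cuda_tensor(),
    # which opens the handle on the sending device (storage._new_shared_cuda).
    t = q.get()
    print(t.device, t)


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    q = mp.SimpleQueue()

    t = torch.ones(1, device="cuda:0")
    q.put(t)  # only an IPC handle and metadata are enqueued, not the data

    p = mp.Process(target=consumer, args=(q,))
    p.start()
    p.join()  # keep t alive here while the consumer still uses it
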
Below is a simple code snippet that reproduces the SIGSEGV. It requires at least 2 GPUs.

import os
import multiprocessing
import faulthandler

import torch.distributed


def init_worker(gpu_id, input_tensors, fn_scatter, fn_compute):
    faulthandler.enable()

    name = multiprocessing.current_process().name

    print(f"{name}: BEGIN INIT WORKER {gpu_id}")

    n_gpu = torch.cuda.device_count()

    # torch.cuda.set_device(gpu_id)
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=n_gpu,
        rank=gpu_id,
    )

    fn_scatter(input_tensors, gpu_id)
    fn_compute(input_tensors, gpu_id)

    print(f"{name}: END INIT WORKER {gpu_id}")


def do_scatter(input_tensors, rank: int):
    device = torch.device(f"cuda:{rank}")
    tensor = torch.empty(1, device=device)

    if rank == 0:
        tensor_list = []
        while not input_tensors.empty():
            tensor_list.append(input_tensors.get())

        torch.distributed.scatter(tensor, scatter_list=tensor_list, src=0)
    else:
        torch.distributed.scatter(tensor, scatter_list=[], src=0)

    print(f"GPU{rank}: {tensor}")

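    # The CUDA tensor created in this worker goes back into the shared queue here;
    # whichever process later calls get() has to rebuild it from a CUDA IPC handle.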
    input_tensors.put(tensor)


def do_compute(input_tensors, rank: int):
    if rank == 0:
        for _ in range(torch.cuda.device_count()):
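            # The posted traceback points at a get() like this one:
            # rebuild_cuda_tensor() -> storage._new_shared_cuda() is what segfaults.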
            x = input_tensors.get()
            input_tensors.put(x * 2)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn", force=True)
    faulthandler.enable()

    n_gpu = torch.cuda.device_count()
    print(f"Found {n_gpu} GPUs")

    master_addr = "localhost"
    master_port = "24206"
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    print(f"Master address is {master_addr}:{master_port}")

    input_tensors = torch.multiprocessing.SimpleQueue()
    for i in range(n_gpu):
        input_tensors.put(torch.tensor([i + 1], dtype=torch.float32, device="cuda:0"))

    torch.multiprocessing.spawn(
        fn=init_worker,
        args=(input_tensors, do_scatter, do_compute),
        nprocs=n_gpu,
    )

    print(input_tensors)

I’m using PyTorch 2.0.1+cu118 on Ubuntu 18.04, with RTX 3090 GPUs connected via NVLink.

Self-answered. The fix is to stop passing CUDA tensors back through the multiprocessing queue from the workers: each worker keeps its scattered tensor on its own GPU, computes on it locally, and rank 0 collects the results with torch.distributed.gather, so no CUDA tensor has to be rebuilt across processes by SimpleQueue.get().

import os
import multiprocessing
import faulthandler

import torch.distributed


def init_worker(gpu_id, input_tensors, fn_scatter, fn_compute, fn_gather):
    faulthandler.enable()

    name = multiprocessing.current_process().name

    print(f"{name}: BEGIN INIT WORKER {gpu_id}")

    n_gpu = torch.cuda.device_count()

    # torch.cuda.set_device(gpu_id)
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=n_gpu,
        rank=gpu_id,
    )

    scattered_tensor = fn_scatter(input_tensors, gpu_id)
    scattered_tensor = fn_compute(scattered_tensor)
    fn_gather(scattered_tensor, gpu_id)

    print(f"{name}: END INIT WORKER {gpu_id}")


def do_scatter(input_tensors, rank: int):
    device = torch.device(f"cuda:{rank}")
    tensor = torch.empty(1, device=device)

    if rank == 0:
        tensor_list = []
        while not input_tensors.empty():
            tensor_list.append(input_tensors.get())

        torch.distributed.scatter(tensor, scatter_list=tensor_list, src=0)
    else:
        torch.distributed.scatter(tensor, scatter_list=[], src=0)

    return tensor

def do_compute(tensor):
    return torch.mul(tensor, 10)


def do_gather(scattered_tensor, rank: int):
    device = torch.device(f"cuda:{rank}")
    tensor = scattered_tensor

    if rank == 0:
        tensor_list = [torch.empty(1, device=device) for _ in range(torch.cuda.device_count())]
        torch.distributed.gather(tensor, gather_list=tensor_list, dst=0)

        print(tensor_list)
    else:
        torch.distributed.gather(tensor, gather_list=[], dst=0)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn", force=True)
    faulthandler.enable()

    n_gpu = torch.cuda.device_count()
    print(f"Found {n_gpu} GPUs")

    master_addr = "localhost"
    master_port = "24206"
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    print(f"Master address is {master_addr}:{master_port}")

    input_tensors = torch.multiprocessing.SimpleQueue()
    for i in range(n_gpu):
        input_tensors.put(torch.tensor([i + 1], dtype=torch.float32, device="cuda:0"))

    torch.multiprocessing.spawn(
        fn=init_worker,
        args=(input_tensors, do_scatter, do_compute, do_gather),
        nprocs=n_gpu,
    )
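
With 2 GPUs this should print the gathered list on rank 0: the queue values 1 and 2 are scattered to the two ranks, multiplied by 10, and gathered back, so something like [tensor([10.], device='cuda:0'), tensor([20.], device='cuda:0')].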