Code freezes with Distributed Data Parallel at init_process_group

Hi, I am new to using DDP and PyTorch. I am trying to train a GNN model using DDP, and I am using the DGL library for handling my graph dataset and loading the dataset, etc. I have the following code for initializing the process group for the code:

def is_global_master(args: argparse.Namespace) -> bool:
    return args.rank == 0

def is_local_master(args: argparse.Namespace) -> bool:
    return args.local_rank == 0

def is_master(args: argparse.Namespace, local: bool = False) -> bool:
    return is_local_master(args) if local else is_global_master(args)

def is_using_distributed() -> bool:
    if 'WORLD_SIZE' in os.environ:
        print("\nWORLD_SIZE in os.environ\n")
        return int(os.environ['WORLD_SIZE']) > 1
    return False

def world_info_from_env() -> "tuple[int, int, int]":
    local_rank, rank, world_size = 0, 0, 1
    if 'LOCAL_RANK' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
    if 'RANK' in os.environ:
        rank = int(os.environ['RANK'])
    if 'WORLD_SIZE' in os.environ:
        world_size = int(os.environ['WORLD_SIZE'])

    return local_rank, rank, world_size

def init_distributed_device(args: argparse.Namespace):
    assert args.device_mode in ('cuda', 'cpu'), f'{args.device_mode=} not supported'
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    is_distributed = is_using_distributed()
    print(f"\nIs using distributed = {is_distributed}\n")
    if is_distributed:
        # DDP via torchrun, torch.distributed.launch
        args.local_rank, args.rank, args.world_size = world_info_from_env()
        # find new available port
        if not _is_free_port(os.environ["MASTER_PORT"]) and is_master(args):
            print("Set MASTER_PORT not free, searching for new free port. \n")
            new_port = _find_free_port()
            os.environ["MASTER_PORT"] = str(new_port)
            print(f'find {new_port=}')
        print("Free port found. \n")
        if args.dist_backend == 'nccl':
            os.environ["NCCL_BLOCKING_WAIT"] = '1'
        print(f"args.dist_backend = {args.dist_backend} \n")
        print("Completed init_process_group")
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        args.distributed = True

    if args.device_mode == 'cuda' and torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            device = f'cuda:{args.local_rank}'
            device = 'cuda:0'
    elif args.device_mode == 'cpu':
        device = 'cpu'
    args.device = device

    if is_using_distributed():
        if is_master(args):
            print(f'Distributed mode enabled. {args.world_size=}')
        print('Not using distributed mode.')

def setup_print_for_distributed(args: argparse.Namespace):
    import builtins
    builtin_print = builtins.print

    def master_only_print(*print_args, **print_kwargs):
        force = print_kwargs.pop("force", False)
        if is_master(args) or force:
            builtin_print(*print_args, **print_kwargs)

    builtins.print = master_only_print

def _is_free_port(port: str | int) -> bool:
    port = int(port)
    ips = socket.gethostbyname_ex(socket.gethostname())[-1]
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return all(s.connect_ex((ip, port)) != 0 for ip in ips)

def _find_free_port() -> int:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    # NOTE: there is still a chance the port could be taken by other processes.
    return int(port)

This is called from the main method as follows:

def main():
    args = DefaultArgs(

if __name__ == '__main__':
    ngpus = torch.cuda.device_count() 
    print(f"No. of GPUs = {ngpus} \n")
    assert ngpus >= 2, f"Requires at least 2 GPUs to run, got {ngpus} GPUs."

I am working on a SLURM cluster. I have a single node, and 2 GPUs, 1 process per GPU. I am running using torchrun --nnodes=1 --nproc_per_node=2 (not sure if I am doing this correct).
The code freezes everytime at the init_process_group, for at least one process, and here is an example output that I get:

No. of GPUs = 2 
WORLD_SIZE in os.environ
Is using distributed = True
Free port found. 
args.dist_backend = nccl 

No. of GPUs = 2 
WORLD_SIZE in os.environ
Is using distributed = True
Set MASTER_PORT not free, searching for new free port. 
find new_port=57085
Free port found. 
args.dist_backend = nccl 
Completed init_process_group

I am setting the environment variables manually as world_size=2, rank=0, local_rank=0, master_addr=‘localhost’, and master_port=‘12345’. What am I doing wrong?

It would be a big help if I could get any insights on this. Thanks!

Why did you only set rank=0? That looks suspicious to me.

I set it with the understanding that I would have to set the rank for the current process. But I might be wrong because I don’t fully understand this yet. Would you suggest doing something different?

I am setting the environment variables manually as world_size=2, rank=0, local_rank=0, master_addr=‘localhost’, and master_port=‘12345’. What am I doing wrong?

No need to set the environment variables manually since master_addr, master_port, world_size, rank, local_rank, among others are set automatically using torchrun (torchrun (Elastic Launch) — PyTorch master documentation)