Training on GPUs (From RuntimeError: Address already in use to timeout)

I am training example code on two nodes with two GPUs per node.

I first got "RuntimeError: Address already in use" and tried many of the solutions suggested here, but nothing worked except removing --ntasks-per-node, which left only one task running per node.

Once I did that, the process would time out after hours of waiting with no output (I'm not sure any training was taking place). CPU and memory utilization seemed minimal.

Example code:
bash script:

#!/bin/bash
#SBATCH --job-name=ddp-torch     # create a short name for your job
#SBATCH --nodes=2                # node count
#SBATCH --cpus-per-task=8        # cpu-cores per task (>1 if multi-threaded tasks)
#SBATCH --mem=32G                # total memory per node (4 GB per cpu-core is default)
#SBATCH --gres=gpu:2             # number of gpus per node

export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))  # derive a job-specific port from the last 4 digits of the job ID
export WORLD_SIZE=$(($SLURM_NNODES * 1))  # one task per node (--ntasks-per-node was removed)
echo "WORLD_SIZE="$WORLD_SIZE

master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)  # use the first node in the allocation as the rendezvous host
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR="$MASTER_ADDR

module load gnu10
source  activate pytorch-env/bin/activate

srun python example.py

example.py:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import argparse
import os
from socket import gethostname


def example(rank, world_size):
    # create default process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()
    dist.destroy_process_group()

def main():
    parser = argparse.ArgumentParser(description='PyTorch Test')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["SLURM_PROCID"])
    gpus_per_node = int(os.environ["SLURM_GPUS_ON_NODE"])
    assert gpus_per_node == torch.cuda.device_count()
    print(f"Hello from rank {rank} of {world_size} on {gethostname()} where the$
          f" {gpus_per_node} allocated GPUs per node.", flush=True)
    
    local_rank = rank - gpus_per_node * (rank // gpus_per_node)
    
    
    torch.cuda.set_device(local_rank)
    mp.spawn(example,
        args=(local_rank,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

Looks like you're passing local_rank to init_process_group, while the global rank should be passed in.

With your 2 hosts of 2 GPUs each, you should have 4 processes, each calling init_process_group with a rank value in the range 0-3.
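
For illustration, here is a minimal sketch that keeps the spawn-based structure from the question but passes the global rank (this assumes one srun task per node as in your current script, with MASTER_ADDR/MASTER_PORT exported as shown there; SLURM_NODEID and SLURM_NNODES are Slurm's node index and node count, and the rank arithmetic is only an example):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def example(local_rank, node_rank, gpus_per_node, world_size):
    # The global rank must be unique across all nodes (0..world_size-1);
    # local_rank only indexes the GPU on the current node.
    rank = node_rank * gpus_per_node + local_rank
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    ddp_model = DDP(nn.Linear(10, 10).to(local_rank), device_ids=[local_rank])
    # ... forward/backward/step as in the original example ...
    dist.destroy_process_group()

def main():
    gpus_per_node = int(os.environ["SLURM_GPUS_ON_NODE"])
    node_rank = int(os.environ["SLURM_NODEID"])
    world_size = int(os.environ["SLURM_NNODES"]) * gpus_per_node  # 2 nodes x 2 GPUs = 4
    # mp.spawn passes the process index (0..nprocs-1) as the first argument,
    # which serves as the local rank here; spawn one process per GPU on this node.
    mp.spawn(example, args=(node_rank, gpus_per_node, world_size),
             nprocs=gpus_per_node, join=True)

if __name__ == "__main__":
    main()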

If you're using Slurm anyway, I would personally avoid using multiprocessing.

If you want to train on N nodes with M GPUs on each node, you can modify the Slurm script to add
#SBATCH --ntasks-per-node=M, where M is the same value set in #SBATCH --gres=gpu:M.
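
For the 2-node, 2-GPU case in the question, the relevant lines would look something like this (a sketch of just the changed lines; the rest of the script stays the same):

#SBATCH --nodes=2                # node count
#SBATCH --ntasks-per-node=2      # M tasks per node, matching gpu:M below
#SBATCH --gres=gpu:2             # number of gpus per node

srun python example.py           # srun launches 2 x 2 = 4 tasks; Slurm sets SLURM_PROCID and SLURM_LOCALID for each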

Slurm will set rank and local-rank environment variables for each task, and you can assign a GPU to each of these tasks. E.g.,

def main():
    world_size = int(os.environ['SLURM_NTASKS'])
    rank = int(os.environ["SLURM_PROCID"])
    local_rank = int(os.environ['SLURM_LOCALID']) 
    # or 
    # gpu  = torch.device(f"cuda:{os.environ['SLURM_LOCALID']}")
    example(local_rank, rank, world_size)

def example(local_rank, rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    ...
    model = nn.Linear(10, 10).to(local_rank)
    ...
    ddp_model = DDP(model, device_ids=[local_rank])
    ...

It will function the same, but the code is simpler and (I think) more efficient than multiprocessing.

Thank you. It worked when I removed multiprocessing.