What is the correct way to launch PyTorch multi-node training on Slurm?

Is the script below the correct way to launch a multi-node PyTorch job on Slurm?

#!/bin/bash
#SBATCH --nodes=2
#SBATCH --gpus-per-node=4
#SBATCH --ntasks-per-node=1 
#SBATCH --cpus-per-task=40
#SBATCH --mem=188000M
#SBATCH --account=def-lkong
#SBATCH --job-name=bitdeit-new-extra-res-small
#SBATCH --output=%x-%j.out
#SBATCH --time=0-168:00     # DD-HH:MM:SS

#SBATCH --mail-type=BEGIN
#SBATCH --mail-type=END
#SBATCH --mail-type=FAIL
#SBATCH --mail-type=REQUEUE


module load python/3.8

source ~/deit/bin/activate

# virtualenv --no-download ~/deit
# source ~/deit/bin/activate
# pip install --no-index --upgrade pip

# pip install torch --no-index
# pip install torchvision --no-index
# pip install timm --no-index
# pip install transformers --no-index

# determine the head node of the allocation; its IP address is used as the rendezvous endpoint
nodes_array=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

# echo "start extracting"
# date
# srun tar -xf /home/lding1/projects/def-lkong/lding1/datasets/ImageNet/imagenet.tar -C $SLURM_TMPDIR
# echo "end extracting"
# date

echo "start extracting"
date

# build one "ssh <node> ..." command per node that copies imagenet.tar into the
# node-local $SLURM_TMPDIR and extracts it, then write the command list to hz_tmp
slurm_hl2hl.py --format MPIHOSTLIST | sed -e 's/^/ssh /g' -e "s%\$% \"cd \$SLURM_TMPDIR;hostname;cp /home/lding1/projects/def-lkong/lding1/datasets/ImageNet/imagenet.tar .; tar -xf imagenet.tar\"%g" | uniq > hz_tmp
# bash ./hz_tmp
echo "end extracting"
date

export LOGLEVEL=INFO

# launch torchrun via srun: one Slurm task per node, 4 worker processes per node
srun --ntasks=$SLURM_NNODES torchrun \
    --nnodes=2 \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $head_node_ip:29500 \
    --nproc_per_node=4 main.py \
    --num-workers=40 \
    --batch-size=64 \
    --epochs=300 \
    --model=configs/deit-small-patch16-224 \
    --dropout=0.0 \
    --drop-path=0.0 \
    --opt=adamw \
    --sched=cosine \
    --weight-decay=0.00 \
    --lr=5e-4 \
    --warmup-epochs=0 \
    --color-jitter=0.0 \
    --aa=noaug \
    --reprob=0.0 \
    --mixup=0.0 \
    --cutmix=0.0 \
    --data-path=$SLURM_TMPDIR \
    --output-dir=logs/bitdeit-new-extra-res-small \
    --teacher-model=configs/deit-small-patch16-224 \
    --teacher-model-file=logs/deit-small-patch16-224/best.pth \
    --model-type=extra-res \
    --weight-bits=1 \
    --input-bits=1 \
    --att-prob-quantizer-type=bit \
    # --replace-ln-bn \
    # --resume=logs/bitdeit-new-extra-res-small/checkpoint.pth \
    # --current-best-model=logs/bitdeit-new-extra-res-small/best.pth \

I’m not a Slurm expert, but that script appears to ask Slurm to launch on 2 nodes and, on each node, to invoke torchrun with 4 processes. The issue I suspect here is that each group of 4 torchrun processes would not know about the other group.

One way to do this is to skip torchrun and write your own launcher script. The idea here would be that slurm creates a process per node, and then your script spawns more proceses but sets up the env variables that torch.distributed/c10d expects (e.g. RANK, WORLD_SIZE, …) and then calls torch.distributed.init_process_group. Here is an example I just found by searching (haven’t tried it, but it looks like it would work): Multi-node-training on slurm with PyTorch · GitHub. Note that in the example, the sbatch is configured for 1 task per node, and 4 nodes. I think you should be able to change this to N tasks per node, to have 1 task per GPU.