What is the correct way to launch multi-node PyTorch training on Slurm? Is the batch script below a correct way to do it?
#!/bin/bash
# Slurm batch script for a multi-node training job.
# The #SBATCH lines are resource requests read by sbatch at submission time;
# bash itself treats them as comments.
#
# Resources: 2 nodes, 4 GPUs per node, and a single task per node with
# 40 CPUs (the script later starts one torchrun per node through srun).
#SBATCH --nodes=2
#SBATCH --gpus-per-node=4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=40
#SBATCH --mem=188000M
#SBATCH --account=def-lkong
#SBATCH --job-name=bitdeit-new-extra-res-small
# %x is the job name and %j the job id, so each run gets its own log file.
#SBATCH --output=%x-%j.out
#SBATCH --time=0-168:00 # D-HH:MM
# E-mail notifications on job begin, end, failure and requeue.
#SBATCH --mail-type=BEGIN
#SBATCH --mail-type=END
#SBATCH --mail-type=FAIL
#SBATCH --mail-type=REQUEUE
# Load the Python module and activate the existing virtualenv in ~/deit.
module load python/3.8
source ~/deit/bin/activate
# Commands that would create the ~/deit virtualenv and install the packages
# (left commented out; the environment is only activated above):
# virtualenv --no-download ~/deit
# source ~/deit/bin/activate
# pip install --no-index --upgrade pip
# pip install torch --no-index
# pip install torchvision --no-index
# pip install timm --no-index
# pip install transformers --no-index
# Determine the rendezvous host for torchrun: the first node of the allocation.
#
# `scontrol show hostnames` expands the compact node list in
# SLURM_JOB_NODELIST into one hostname per line. Read it into a real array
# (the previous `nodes_array=($nodes)` copied only the first element, because
# a bare `$nodes` expands to element 0 of the array).
mapfile -t nodes_array < <(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes=("${nodes_array[@]}")
head_node=${nodes_array[0]}
if [[ -z "$head_node" ]]; then
  echo "could not determine the head node from SLURM_JOB_NODELIST" >&2
  exit 1
fi

# Ask the head node itself for its address. `hostname --ip-address` may print
# several space-separated addresses; keep only the first one so it can be used
# as a single host in "<ip>:<port>".
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
head_node_ip=${head_node_ip%% *}
if [[ -z "$head_node_ip" ]]; then
  echo "could not resolve an IP address for head node $head_node" >&2
  exit 1
fi
# Stage the dataset into node-local storage ($SLURM_TMPDIR) on every node.
#
# The previous version only generated a file of ssh commands (hz_tmp) and
# never executed it, so nothing was extracted even though the training
# command reads its data from $SLURM_TMPDIR. Run the extraction through srun
# instead: one task per node, each unpacking the archive into its own
# node-local directory.
dataset_tar=/home/lding1/projects/def-lkong/lding1/datasets/ImageNet/imagenet.tar

echo "start extracting"
date
# Single quotes: $SLURM_TMPDIR must be expanded by the shell on each node,
# not by this batch shell. The archive path is passed as "$1".
if ! srun --ntasks="$SLURM_NNODES" --ntasks-per-node=1 \
  bash -c 'hostname; tar -xf "$1" -C "$SLURM_TMPDIR"' _ "$dataset_tar"; then
  echo "dataset extraction failed" >&2
  exit 1
fi
echo "end extracting"
date
export LOGLEVEL=INFO

# Launch one torchrun agent per node through srun; each agent starts the
# per-GPU worker processes on its node and joins the c10d rendezvous hosted
# on the head node.
#
# Arguments are kept in arrays so that every expansion is quoted and optional
# flags can be commented in or out without relying on a backslash-continued
# line that runs into comment lines.

# Options for torchrun itself (everything before main.py).
torchrun_args=(
  # Taken from the allocation so it always agrees with --ntasks below.
  --nnodes="$SLURM_NNODES"
  # Expanded once here, in the batch shell, so all nodes share the same id.
  --rdzv_id "$RANDOM"
  --rdzv_backend c10d
  # Keep only the first address in case several were reported for the host.
  --rdzv_endpoint "${head_node_ip%% *}:29500"
  # Must match the "#SBATCH --gpus-per-node" request.
  --nproc_per_node=4
)

# Options for the training script.
train_args=(
  # NOTE(review): presumably data-loader workers per process; with 4
  # processes per node this may oversubscribe the 40 CPUs. Confirm.
  --num-workers=40
  --batch-size=64
  --epochs=300
  --model=configs/deit-small-patch16-224
  --dropout=0.0
  --drop-path=0.0
  --opt=adamw
  --sched=cosine
  --weight-decay=0.00
  --lr=5e-4
  --warmup-epochs=0
  --color-jitter=0.0
  --aa=noaug
  --reprob=0.0
  --mixup=0.0
  --cutmix=0.0
  --data-path="$SLURM_TMPDIR"
  --output-dir=logs/bitdeit-new-extra-res-small
  --teacher-model=configs/deit-small-patch16-224
  --teacher-model-file=logs/deit-small-patch16-224/best.pth
  --model-type=extra-res
  --weight-bits=1
  --input-bits=1
  --att-prob-quantizer-type=bit
  # Optional flags, currently disabled:
  # --replace-ln-bn
  # --resume=logs/bitdeit-new-extra-res-small/checkpoint.pth
  # --current-best-model=logs/bitdeit-new-extra-res-small/best.pth
)

# Last command of the script, so its exit status becomes the job's status.
srun --ntasks="$SLURM_NNODES" torchrun "${torchrun_args[@]}" main.py "${train_args[@]}"