When using PyTorch elastic training, how do I broadcast the latest checkpoint to newly joined nodes?

Hello all!

I'd like to ask how to broadcast the latest checkpoint from the existing nodes to a newly joined node when using elastic distributed training.

Suppose I submit an elastic training job to a GPU cluster with nnodes=1:2 and launch the training script with the following command:

MIN_SIZE=1
MAX_SIZE=2
python -m torch.distributed.run \
    --nnodes=$MIN_SIZE:$MAX_SIZE \
    --nproc_per_node=4 \
    --rdzv_id=$JOB_ID \
    --rdzv_endpoint=$ADDR:$PORT \
    --rdzv_backend=c10d \
    my_train_script.py

Initially, I get a first node with 4 GPUs, and training starts because MIN_SIZE=1. In my_train_script.py, I save a checkpoint at the end of every epoch.
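A minimal sketch of the per-epoch save, assuming the script holds a `model`, an `optimizer`, and an integer `epoch`. The helper `save_checkpoint`, the checkpoint keys, and the `.tmp` suffix are hypothetical choices, not part of any PyTorch API. Only global rank 0 writes, and it writes to a temporary file first and then renames it, so a worker that is restarted mid-save never reads a truncated checkpoint.

```python
import os

import torch
from torch.nn.parallel import DistributedDataParallel


def save_checkpoint(path, model, optimizer, epoch, rank):
    """Hypothetical helper: save model/optimizer state and the finished epoch.

    Only rank 0 writes; all other ranks return without touching the disk.
    """
    if rank != 0:
        return
    # Unwrap DDP so the saved keys do not carry the "module." prefix.
    if isinstance(model, DistributedDataParallel):
        model = model.module
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,  # last epoch that finished
    }
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    # os.replace is atomic on POSIX when both paths are on the same filesystem.
    os.replace(tmp_path, path)
```

Under torchrun the global rank is exposed as the `RANK` environment variable, so a call at the end of each epoch might look like `save_checkpoint("ckpt.pt", model, optimizer, epoch, int(os.environ["RANK"]))`.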

After some training (say, 3 epochs), I get a second node with 4 GPUs and run the same command on it. I want the workers on the second node to load the checkpoint from the first node so that all workers have the same model state, optimizer state, etc. Could anyone please tell me how to implement this?

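This is much easier if the checkpoint is written to shared storage. When a new node joins, torchrun stops all existing workers and restarts the whole worker group with new RANK and WORLD_SIZE values, so every worker, including those on the new node, re-runs the script from the top and can simply load the latest checkpoint from the shared path.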
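With shared storage, the resume step can be a small helper that every worker runs at startup. This sketch assumes the checkpoint is a dict with hypothetical `"model"`, `"optimizer"`, and `"epoch"` keys (where `"epoch"` is the last finished epoch) and that the helper is called before the model is wrapped in DDP; `load_checkpoint` is an illustrative name, not a PyTorch function.

```python
import os

import torch


def load_checkpoint(path, model, optimizer):
    """Hypothetical helper: restore state from `path` if it exists.

    Returns the epoch to resume from (0 when no checkpoint is found).
    """
    if not os.path.exists(path):
        return 0
    # Load onto CPU: load_state_dict copies the values into the existing
    # (possibly GPU) parameters, and the optimizer casts its state to the
    # device of the parameters it manages.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    # The checkpoint records the last finished epoch; resume at the next one.
    return state["epoch"] + 1
```

`map_location` matters here: without it, tensors saved from a GPU are restored onto the device index they were saved from, which may not be the device assigned to a worker on the new node.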

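If the storage is local, one way to implement this is to bind ranks 0-3 to the first node and always broadcast the model state from rank 0 to all other ranks before the training loop starts. DistributedDataParallel (PyTorch 1.12 documentation) already does this for you in its constructor, but only for the module's parameters and buffers; optimizer state and counters such as the epoch number still have to be broadcast separately.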
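For the local-storage case, here is a minimal sketch of sending the remaining state from rank 0, assuming rank 0 runs on the node that holds the checkpoint (as suggested above) and that the checkpoint uses hypothetical `"model"`, `"optimizer"`, and `"epoch"` keys. `sync_from_rank0` is an illustrative name; the collective it relies on, `torch.distributed.broadcast_object_list`, is a real API that pickles arbitrary Python objects. Rank 0 loads the model weights locally and the DDP constructor then copies them to the other ranks, while the optimizer state and epoch counter travel through the broadcast.

```python
import os

import torch
import torch.distributed as dist


def sync_from_rank0(ckpt_path, model, optimizer):
    """Hypothetical helper: rank 0 reads a local checkpoint; every rank gets
    the optimizer state and epoch counter through a broadcast.

    Returns the epoch to resume from. Call it before wrapping the model in
    DistributedDataParallel, whose constructor broadcasts rank 0's weights.
    """
    payload = [None]
    if dist.get_rank() == 0 and os.path.exists(ckpt_path):
        # Load onto CPU so the pickled object carries no GPU device ids.
        state = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state["model"])
        payload[0] = {"optimizer": state["optimizer"], "epoch": state["epoch"]}
    # With the NCCL backend, call torch.cuda.set_device(local_rank) first.
    dist.broadcast_object_list(payload, src=0)
    if payload[0] is None:
        return 0  # rank 0 found no checkpoint: every rank starts from scratch
    optimizer.load_state_dict(payload[0]["optimizer"])
    return payload[0]["epoch"] + 1
```

Every rank must call it, after `dist.init_process_group(...)` and before `model = DistributedDataParallel(model, ...)`, because a collective only completes once all ranks in the group have entered it.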