What is the proper way to checkpoint during training when using distributed data parallel (DDP) in PyTorch?

I want the proper, officially recommended, bug-free way to do the following:

  1. resume from a checkpoint to continue training on multiple gpus
  2. save checkpoint correctly during training with multiple gpus

For that my guess is the following:

  1. To do 1, every process loads the checkpoint from the file and then calls DDP(mdl). I assume the checkpoint was saved from ddp_mdl.module.state_dict().
  2. To do 2, check which process has rank 0 and have only that one call torch.save({'model': ddp_mdl.module.state_dict()}, path).
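
Saving ddp_mdl.module.state_dict() rather than ddp_mdl.state_dict() matters because the DDP wrapper stores the original model under the attribute `module`, so the wrapper's keys carry a `module.` prefix. Below is a minimal sketch of converting one key layout to the other, using plain dictionaries in place of real tensors. The helper name strip_ddp_prefix is hypothetical, not a PyTorch function.

```python
from collections import OrderedDict


def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DDP prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Keys as they would look if saved from the DDP wrapper itself
# (values are placeholders standing in for tensors).
wrapped = OrderedDict([("module.fc.weight", 1), ("module.fc.bias", 2)])
plain = strip_ddp_prefix(wrapped)
print(list(plain.keys()))  # ['fc.weight', 'fc.bias']
```

With the unprefixed layout, the checkpoint can be loaded into a plain Net before it is wrapped in DDP, which is what my guess in step 1 relies on.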

Approximate code:

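def save_ckpt(rank, ddp_model, optimizer, path):
    # Only rank 0 writes; under DDP every rank holds the same parameters.
    if rank == 0:
        state = {
            # .module is the unwrapped model, so keys have no 'module.' prefix
            'model': ddp_model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, path)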

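def load_ckpt(path, distributed, gpu, map_location=torch.device('cpu')):
    # gpu is this process's local device index.
    # Load tensors onto the CPU first so no rank touches another rank's GPU.
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...)
    model.load_state_dict(checkpoint['model'])
    # Move the model before building the optimizer so that the optimizer
    # state is placed on the same device as the parameters when loaded.
    model.to(gpu)
    optimizer = ...
    optimizer.load_state_dict(checkpoint['optimizer'])
    if distributed:
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    return model, optimizer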

Is this correct?


One of the reasons I am asking is that distributed code can go subtly wrong, and I want to make sure that does not happen to me. Of course I also want to avoid deadlocks, but those would be obvious if they happened (e.g. if all the processes somehow tried to open the same checkpoint file at the same time). In that case I’d either make sure the processes load it one at a time, or have only rank 0 load it and then send it to the rest of the processes.

I am also asking because the official docs don’t make sense to me. I will paste their code and explanation since links can die sometimes:

Save and Load Checkpoints
It’s common to use torch.save and torch.load to checkpoint modules during training and recover from checkpoints. See SAVING AND LOADING MODELS for more details. When using DDP, one optimization is to save the model in only one process and then load it to all processes, reducing write overhead. This is correct because all processes start from the same parameters and gradients are synchronized in backward passes, and hence optimizers should keep setting parameters to the same values. If you use this optimization, make sure all processes do not start loading before the saving is finished. Besides, when loading the module, you need to provide an appropriate map_location argument to prevent a process to step into others’ devices. If map_location is missing, torch.load will first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. For more advanced failure recovery and elasticity support, please refer to TorchElastic.

def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn = nn.MSELoss()
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
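
On the map_location line in the quoted example: the dictionary form remaps the device tag recorded at save time to the local device. Here is a minimal sketch, assuming the checkpoint was written from the GPU of a known rank. The helper name make_map_location and its saved_rank parameter are hypothetical, not part of PyTorch.

```python
def make_map_location(rank, saved_rank=0):
    """Build a dict that remaps the saving rank's GPU tag to this rank's GPU."""
    return {f"cuda:{saved_rank}": f"cuda:{rank}"}


print(make_map_location(2))  # {'cuda:0': 'cuda:2'}
```

The returned dictionary would be passed as torch.load(path, map_location=make_map_location(rank)), the same way as the inline dictionary in the docs example.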


Hi @Brando_Miranda,

Yes, your understanding is correct. Save the model (or any other sync’ed artifacts) to a permanent store on rank 0 during checkpointing, and load it on all ranks while resuming. What our documentation describes is the common idiom for saving/loading checkpoints, but if you have different requirements, you don’t have to follow it. For instance if your underlying store requires some form of sequential access, you can coordinate your workers using collective calls or coordinate them via a Store (e.g. TCPStore). In our example we have a basic demonstration of this coordination via a barrier call.
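
If the store needs sequential access, one way to take turns with barriers can be sketched as follows. This is only an illustration: the function name load_one_rank_at_a_time is hypothetical, and threads with a threading.Barrier stand in for DDP processes so the sketch runs without GPUs. In a real job you would pass torch.distributed.barrier as the barrier argument and a function that calls torch.load as load_fn.

```python
import threading


def load_one_rank_at_a_time(rank, world_size, load_fn, barrier):
    """Call load_fn() on each rank in turn, with a barrier after every turn."""
    result = None
    for turn in range(world_size):
        if turn == rank:
            result = load_fn()
        # Every rank reaches this barrier once per turn, so no rank
        # starts loading until the previous rank has finished.
        barrier()
    return result


# Simulation: threads play the role of ranks (an assumption for this demo).
WORLD_SIZE = 4
order = []
sim_barrier = threading.Barrier(WORLD_SIZE)


def worker(rank):
    def fake_load():
        order.append(rank)  # stands in for torch.load(path, map_location=...)
        return {"loaded_by": rank}

    load_one_rank_at_a_time(rank, WORLD_SIZE, fake_load, sim_barrier.wait)


threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(order)  # [0, 1, 2, 3]
```

Every rank must call the barrier the same number of times, which is why the barrier sits outside the `if`. Calling it only on the loading rank would hang the job.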

Do you mind explaining why the doc did not make sense to you? We always aim to improve our documentation, so any concrete feedback would be greatly appreciated.
