Checkpointing DDP.module instead of DDP itself

I’d like to double-check because I see multiple questions about this and the official doc is confusing to me. I want to:

  1. resume from a checkpoint to continue training on multiple GPUs
  2. save checkpoints correctly during training with multiple GPUs

For that my guess is the following:

  1. to do 1, have every process load the checkpoint from the file and then call DDP(mdl) in each process. I assume the checkpoint was saved from ddp_mdl.module.state_dict().
  2. to do 2, simply check which process has rank == 0 and have only that one do torch.save({'model': ddp_mdl.module.state_dict()})

Is this correct?

I am not sure why there are so many other posts about this (also this), or what the official doc is getting at. I don’t see why I’d want to optimize a single write (or read) given that I will train my model for anywhere from 100 epochs to 600K iterations… that one write seems negligible to me.
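For what it’s worth, my reading of the tutorial’s pattern is roughly the sketch below (just a sketch: it assumes the process group is already initialized, one process per GPU, and the function name and path are made up). The point seems to be that only rank 0 writes, the barrier keeps the other ranks from reading a half-written file, and map_location redirects the tensors saved from cuda:0 onto each rank’s own GPU:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def demo_checkpoint(rank, model):
    # assumes dist.init_process_group(...) has already run in this process
    ddp_model = DDP(model.to(rank), device_ids=[rank])

    CHECKPOINT_PATH = '/tmp/ddp_checkpoint.pt'  # made-up path
    if rank == 0:
        # one rank writes the (unwrapped) weights once
        torch.save(ddp_model.module.state_dict(), CHECKPOINT_PATH)

    # make sure rank 0 finished writing before anyone reads
    dist.barrier()

    # remap the tensors saved from cuda:0 onto this rank's GPU
    map_location = {'cuda:0': f'cuda:{rank}'}
    state_dict = torch.load(CHECKPOINT_PATH, map_location=map_location)
    ddp_model.module.load_state_dict(state_dict)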


I think the correct code is this one:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def save_ckpt(rank, ddp_model, optimizer, path):
    # only rank 0 writes; checkpoint the underlying module, not the DDP wrapper
    if rank == 0:
        state = {'model': ddp_model.module.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 }
        torch.save(state, path)


def load_ckpt(path, gpu, distributed, map_location=torch.device('cpu')):
    # load onto the CPU first, then move/wrap per process
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...)
    optimizer = ...
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if distributed:
        model = model.to(gpu)  # DDP expects the module to already be on its GPU
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    return model, optimizer
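And roughly how I’d wire those two functions into each spawned worker (again just a sketch: setup, Net, train_one_epoch, and num_epochs are placeholders for whatever you already have):

def worker(rank, world_size, resume_path=None):
    setup(rank, world_size)  # placeholder: dist.init_process_group(...) etc.

    if resume_path is not None:
        # every process loads the same checkpoint, then wraps its copy in DDP
        model, optimizer = load_ckpt(resume_path, gpu=rank, distributed=True)
    else:
        model = Net(...).to(rank)
        optimizer = ...
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer)  # placeholder training loop
        save_ckpt(rank, model, optimizer, path=f'ckpt_{epoch}.pt')
        dist.barrier()  # let rank 0 finish writing before the next epoch starts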

Also, I added more details here: python - What is the proper way to checkpoint during training when using distributed data parallel (DDP) in PyTorch? - Stack Overflow