Checkpointing DDP.module instead of DDP itself

I am using DDP to distribute training across multiple GPUs.

model = Net(...)
ddp_model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 
ddp_model = DDP(ddp_model, device_ids=[gpu], find_unused_parameters=True)

When checkpointing, is it OK to save ddp_model.module instead of ddp_model? I need to be able to use the checkpoint for (1) evaluation on a single GPU and (2) resuming training on multiple GPUs.

def save_ckpt(ddp_model, path):
    state = {'model': ddp_model.module.state_dict(),
             'optimizer': optimizer.state_dict(),
            }
    torch.save(state, path)

def load_ckpt(path, distributed):
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...)
    optimizer = ...
    model.load_state_dict(checkpoint['model'], strict=False)
    optimizer.load_state_dict(checkpoint['optimizer'])
    if distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    return model

I am not sure how this would behave. Does manipulating ddp_model.module break things in general?
Similarly, what if the number of GPUs changes? Would that affect the optimizer, for example, and would it need to be reinitialized? Thank you!


Yep, this is actually the recommended way:

  1. On save, use one rank to save ddp_model.module to checkpoint.
  2. On load, first use the checkpoint to load a local model, and then wrap the local model instance with DDP on all processes (i.e., something like DistributedDataParallel(local_model, device_ids=[rank])).
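
For what it's worth, here is a minimal sketch of those two steps. The helper names (`unwrap`, `save_ckpt`, `load_ckpt`) and the explicit `optimizer` argument are my own assumptions for illustration, not an official API; the DDP wrapping itself is left to the caller so the snippet also runs on a single CPU process.

```python
import torch
import torch.nn as nn


def unwrap(model):
    # DDP keeps the original model in `.module`; a plain model has no such
    # attribute, so it is returned unchanged.
    return getattr(model, "module", model)


def save_ckpt(rank, model, optimizer, path):
    # Step 1: a single rank writes the unwrapped model's state_dict.
    if rank == 0:
        state = {
            "model": unwrap(model).state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        torch.save(state, path)


def load_ckpt(path, local_model, optimizer, map_location="cpu"):
    # Step 2: every process loads into its local (not yet wrapped) model.
    # The caller wraps the returned model with DDP afterwards if needed.
    checkpoint = torch.load(path, map_location=map_location)
    local_model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return local_model, optimizer
```

Because the saved keys carry no wrapper prefix, the same file can be loaded into a plain model for single-GPU evaluation or into a local model that is then wrapped with DDP to resume training.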

I'd like to double-check, because I see multiple questions about this and the official doc is confusing to me. I want to:

  1. resume from a checkpoint to continue training on multiple gpus
  2. save checkpoint correctly during training with multiple gpus

For that my guess is the following:

  1. To do 1, have all the processes load the checkpoint from the file, then call DDP(mdl) in each process. I assume the checkpoint saved ddp_mdl.module.state_dict().
  2. To do 2, simply check which process has rank 0 and have that one call torch.save({'model': ddp_mdl.module.state_dict()}, path).

Is this correct?

I am not sure why there are so many other posts (also this) or what the official doc is talking about. I don't get why I'd want to optimize a single write (or read), given that I will train my model for anywhere from 100 epochs to 600K iterations. Optimizing that one write seems pointless to me.


I think the correct code is this one:

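def save_ckpt(rank, ddp_model, path):
    # only rank 0 writes the checkpoint; `optimizer` comes from the enclosing scope
    if rank == 0:
        state = {'model': ddp_model.module.state_dict(),
                 'optimizer': optimizer.state_dict(),
                }
        torch.save(state, path)

def load_ckpt(path, distributed, map_location=torch.device('cpu')):
    # loads the checkpoint tensors onto map_location (CPU by default)
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...)
    optimizer = ...
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if distributed:
        # DDP with device_ids expects the parameters to already be on that GPU;
        # `gpu` is this process's device index, taken from the enclosing scope
        model = model.to(gpu)
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    return model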
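
As a side note on why saving ddp_model.module matters for single-GPU evaluation: a state_dict taken from the wrapper itself has every key prefixed with "module.", because the wrapper stores the model as a submodule named `module`. Below is a rough sketch of stripping that prefix from such a checkpoint. `Wrapper` is only a stand-in I made up so the example runs without a process group, and `strip_module_prefix` is a hypothetical helper, not part of PyTorch.

```python
from collections import OrderedDict

import torch.nn as nn


class Wrapper(nn.Module):
    """Stand-in for DDP: like DDP, it stores the model as `self.module`."""

    def __init__(self, module):
        super().__init__()
        self.module = module


def strip_module_prefix(state_dict, prefix="module."):
    # Drop a leading "module." from each key; other keys pass through unchanged.
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


net = nn.Linear(4, 2)
wrapped_state = Wrapper(net).state_dict()          # keys carry the "module." prefix
plain_model = nn.Linear(4, 2)
plain_model.load_state_dict(strip_module_prefix(wrapped_state))
```

Saving ddp_model.module.state_dict() in the first place avoids needing this step at all.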

Also, I added more details here: "What is the proper way to checkpoint during training when using distributed data parallel (DDP) in PyTorch?" on Stack Overflow.