Yep, this is actually the recommended way:
- On save, use one rank to save `ddp_model.module` to the checkpoint.
- On load, first use the checkpoint to load a local model, and then wrap the local model instance with DDP on all processes (i.e., something like `DistributedDataParallel(local_model, device_ids=[rank])`).
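
As a rough sketch, the two steps could look something like the code below. The helper names `save_checkpoint` and `load_checkpoint` are made up for illustration, and a few things here are assumptions rather than part of the recipe above: it saves the `state_dict()` of `ddp_model.module` instead of pickling the module itself, it treats `rank` as the local GPU index (one process per GPU on a single node), and it falls back to CPU when CUDA is not available.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def save_checkpoint(ddp_model: DistributedDataParallel, path: str, rank: int) -> None:
    # Hypothetical helper. Only rank 0 writes the file. ddp_model.module is
    # the unwrapped local model, so the saved keys have no "module." prefix.
    if rank == 0:
        torch.save(ddp_model.module.state_dict(), path)
    # Wait until the file is written before any rank moves on to read it.
    dist.barrier()


def load_checkpoint(local_model: nn.Module, path: str, rank: int) -> DistributedDataParallel:
    # Hypothetical helper. Assumes rank doubles as the local GPU index.
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{rank}" if use_cuda else "cpu")
    # map_location places the tensors on this process's own device instead
    # of the device they were saved from.
    state_dict = torch.load(path, map_location=device)
    local_model.load_state_dict(state_dict)
    local_model.to(device)
    # Every process wraps its own local copy with DDP.
    return DistributedDataParallel(local_model, device_ids=[rank] if use_cuda else None)
```

Both helpers expect the process group to be initialized already (e.g. via `dist.init_process_group`), and every rank has to call `save_checkpoint` so that the barrier does not hang.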