Preferred way of loading a checkpoint with DDP training

Hello, I am using DDP to distribute training across multiple GPUs. I now want to resume training from a checkpoint. I tried the following two ways of loading it, and I would like to know which one is preferred.

First option: load the weights into the model on the CPU, then move it to the GPU and wrap it with DDP.

import os

import torch
from torch.nn.parallel import DistributedDataParallel

torch.distributed.init_process_group("nccl", init_method="env://")
local_rank = int(os.environ["LOCAL_RANK"])
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(local_rank)  # Set up the device for this process.
torch.cuda.empty_cache()

# Load the checkpoint onto the CPU before the model is moved to the GPU.
model = net()
model.load_state_dict(torch.load(model_ckpt_path, map_location="cpu"))

# Move the model to this process's GPU and wrap it with DDP.
device = torch.cuda.current_device()
model = model.to(device)
ddp_model = DistributedDataParallel(model, device_ids=[local_rank])

Second option: wrap the model with DDP first, then load the weights into ddp_model.module on each process.

import os

import torch
from torch.nn.parallel import DistributedDataParallel

torch.distributed.init_process_group("nccl", init_method="env://")
local_rank = int(os.environ["LOCAL_RANK"])
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(local_rank)  # Set up the device for this process.
torch.cuda.empty_cache()

device = torch.cuda.current_device()
model = net().to(device)
ddp_model = DistributedDataParallel(model, device_ids=[local_rank])

# Load the checkpoint into the wrapped module on each process.
ddp_model.module.load_state_dict(
    torch.load(model_ckpt_path, map_location=torch.device("cuda", device))
)

torch.distributed.barrier()  # Wait for all processes to finish loading the checkpoint.

Is there any difference in the results between the two? I personally prefer the second way, since it lets me wrap a load_ckpt method inside a trainer class.
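For context, this is roughly what I mean by wrapping it in a trainer class (a minimal sketch; Trainer and load_ckpt are just illustrative names, and it assumes torch.distributed is already initialized as in the snippets above):

import torch
from torch.nn.parallel import DistributedDataParallel


class Trainer:
    def __init__(self, model, local_rank):
        self.device = torch.device("cuda", local_rank)
        # Wrap the model with DDP first, as in the second option.
        self.ddp_model = DistributedDataParallel(
            model.to(self.device), device_ids=[local_rank]
        )

    def load_ckpt(self, ckpt_path):
        # Load the weights into the underlying module after the DDP wrapper
        # has already been constructed.
        state_dict = torch.load(ckpt_path, map_location=self.device)
        self.ddp_model.module.load_state_dict(state_dict)
        torch.distributed.barrier()  # Wait until every rank has finished loading.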

The documentation of DistributedDataParallel has a warning that says:

You should never try to change your model’s parameters after wrapping up your model with DistributedDataParallel. Because, when wrapping up your model with DistributedDataParallel, the constructor of DistributedDataParallel will register the additional gradient reduction functions on all the parameters of the model itself at the time of construction. If you change the model’s parameters afterwards, gradient reduction functions no longer match the correct set of parameters.

So I think the first way is the safer one.
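For reference, this is a quick sanity check I could add to the second snippet to see whether loading the state dict replaces the parameter tensors that DDP registered (a sketch reusing ddp_model, device, and model_ckpt_path from above):

# Record the identity of each parameter tensor that DDP registered at
# construction time.
params_before = {name: id(p) for name, p in ddp_model.module.named_parameters()}

ddp_model.module.load_state_dict(
    torch.load(model_ckpt_path, map_location=torch.device("cuda", device))
)

# If the ids still match, the parameters were updated in place rather than
# replaced by new tensors.
params_after = {name: id(p) for name, p in ddp_model.module.named_parameters()}
assert params_before == params_after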