Preferred way of loading checkpoint with DDP training

Hello, I am using DDP to distribute training across multiple GPUs. Now I want to resume training from a checkpoint. I tried the following two ways of loading the checkpoint, and I would like to know which one is preferred.

First, load the checkpoint into the model on the CPU, and then wrap the model with DDP.

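```python
import os

import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel

# `net` and `model_ckpt_path` are defined elsewhere in the training script.
torch.distributed.init_process_group("nccl", init_method="env://")
local_rank = int(os.environ["LOCAL_RANK"])  # set per process by the launcher (e.g. torchrun)
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(local_rank)  # Set up the device for this process.
torch.cuda.empty_cache()

# Load the checkpoint into the plain model on the CPU ...
model = net()
model.load_state_dict(torch.load(model_ckpt_path, map_location="cpu"))

# ... then move the model to this process's GPU and wrap it with DDP.
device = torch.cuda.current_device()
model = model.to(device)
ddp_model = DistributedDataParallel(model, device_ids=[local_rank])
```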

Second, wrap the model with DDP and then load the weights into ddp_model.module on each process.

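```python
import os

import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel

# `net` and `model_ckpt_path` are defined elsewhere in the training script.
torch.distributed.init_process_group("nccl", init_method="env://")
local_rank = int(os.environ["LOCAL_RANK"])  # set per process by the launcher (e.g. torchrun)
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(local_rank)  # Set up the device for this process.
torch.cuda.empty_cache()

# Move the freshly initialized model to this process's GPU and wrap it with DDP ...
device = torch.cuda.current_device()
model = net().to(device)
ddp_model = DistributedDataParallel(model, device_ids=[local_rank])

# ... then load the checkpoint into the wrapped module, mapping tensors onto this GPU.
ddp_model.module.load_state_dict(
    torch.load(model_ckpt_path, map_location=torch.device("cuda", device)))

torch.distributed.barrier()  # Wait for all processes to finish loading the checkpoint.
```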

Is there any difference in the result? I personally prefer the second way, since it lets me put the load_ckpt method in a trainer class.
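Since the main appeal of the second variant is letting a trainer object own checkpoint loading, here is a minimal sketch of such a method. The `Trainer` class and its `load_ckpt` method are hypothetical names for illustration, not part of PyTorch. The sketch assumes the trainer holds either a plain module or a DDP-wrapped one: it unwraps `.module` when it finds a `DistributedDataParallel` instance, maps the checkpoint onto the device the parameters already live on, and only calls `barrier()` when a process group is initialized, so the usage at the bottom runs in a single CPU process.

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


class Trainer:
    """Hypothetical trainer owning a model that may or may not be DDP-wrapped."""

    def __init__(self, model: torch.nn.Module):
        self.model = model

    def _unwrapped(self) -> torch.nn.Module:
        # DDP keeps the user's module under `.module`; loading into it means the
        # checkpoint keys match a state_dict saved from the plain model.
        if isinstance(self.model, DistributedDataParallel):
            return self.model.module
        return self.model

    def load_ckpt(self, path: str) -> None:
        module = self._unwrapped()
        # Put every checkpoint tensor on the device the parameters already use.
        device = next(module.parameters()).device
        state_dict = torch.load(path, map_location=device)
        # Copies values into the existing parameter tensors (no new Parameters).
        module.load_state_dict(state_dict)
        # In a multi-process run, wait until every rank has finished loading.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()


# Single-process usage: save a "checkpoint" from one model, load it into another.
src = torch.nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save(src.state_dict(), path)

trainer = Trainer(torch.nn.Linear(4, 2))
trainer.load_ckpt(path)
```

In a real DDP run, the same `load_ckpt` call would be made on every rank after `DistributedDataParallel(...)` has been constructed, as in the second variant.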


The DistributedDataParallel documentation contains a warning that says:

You should never try to change your model’s parameters after wrapping up your model with DistributedDataParallel. Because, when wrapping up your model with DistributedDataParallel, the constructor of DistributedDataParallel will register the additional gradient reduction functions on all the parameters of the model itself at the time of construction. If you change the model’s parameters afterwards, gradient reduction functions no longer match the correct set of parameters.

So I think the first way should be the safer way.

I just came across this post, and for everyone else reading it, I’d like to point out that the warning ‘don’t change model parameters after DDP wrapping’ does NOT mean you shouldn’t load a checkpoint afterwards. It means that you should not change the model’s architecture, i.e. remove or add modules with new parameters, because those would not be tracked for gradient reduction. Loading a checkpoint with load_state_dict only overwrites the values of the existing parameters, so the gradient reduction functions still match them.
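To make the distinction concrete, here is a small CPU-only check (no process group needed) using a stand-in `torch.nn.Linear` as the model. It relies on `load_state_dict` copying checkpoint values into the existing parameter tensors by default, so the `Parameter` objects DDP would have attached its reduction hooks to stay the same objects. Assigning a new `torch.nn.Parameter` is the kind of change the warning is about. This illustrates the mechanism; it is not a check that DDP itself performs.

```python
import torch

# A stand-in for the model DDP would wrap, and a "checkpoint" taken from a
# second, independently initialized instance of the same architecture.
model = torch.nn.Linear(3, 3)
checkpoint = torch.nn.Linear(3, 3).state_dict()

# Record the identity of every registered parameter before loading.
ids_before = [id(p) for p in model.parameters()]

# Loading the checkpoint copies values into the existing tensors.
model.load_state_dict(checkpoint)
ids_after_load = [id(p) for p in model.parameters()]

same_objects = ids_before == ids_after_load
values_loaded = all(torch.equal(model.state_dict()[k], v) for k, v in checkpoint.items())

# Replacing a parameter, by contrast, registers a brand-new Parameter object
# that hooks attached to the old one would know nothing about.
model.bias = torch.nn.Parameter(torch.zeros(3))
ids_after_replace = [id(p) for p in model.parameters()]
replaced_changes_identity = ids_after_replace != ids_before

print(same_objects, values_loaded, replaced_changes_identity)  # True True True
```

Note that recent PyTorch releases also accept `load_state_dict(..., assign=True)`, which replaces the parameters instead of copying into them; that mode would fall under the warning.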

Regarding checkpoint loading, both variants are fine. In fact, the reference implementation of torchvision does the second variant, have a look here. The difference between the two variants should be the following:

TLDR: same result

  1. init model → load checkpoint → wrap DDP: depending on seeding, every rank initializes identical or different random weights, all ranks load the same checkpoint, then DDP syncs the weights from rank 0 to the other ranks (they are already identical at this point).
  2. init model → wrap DDP → load checkpoint: depending on seeding, every rank initializes identical or different random weights, DDP syncs the weights from rank 0 to the other ranks, then all ranks load the same checkpoint and overwrite them.
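The two orderings can be checked with a toy, pure-Python model of what happens on each rank. This is not PyTorch code: weights are plain dicts, `broadcast_from_rank0` stands in for the sync that DDP’s constructor performs, and `CHECKPOINT`, `WORLD_SIZE`, and the seeding scheme are made-up values for illustration. Both orderings end with every rank holding exactly the checkpoint weights, whether or not the ranks started from the same seed.

```python
import random

CHECKPOINT = {"w": 0.5, "b": -1.0}  # hypothetical checkpoint every rank reads
WORLD_SIZE = 4                      # hypothetical number of ranks


def init_model(rank, seed_per_rank):
    # Random init: a shared seed gives identical weights on every rank,
    # a per-rank seed gives different weights.
    rng = random.Random(rank if seed_per_rank else 0)
    return {"w": rng.random(), "b": rng.random()}


def broadcast_from_rank0(models):
    # Stand-in for DDP construction: every rank receives rank 0's weights.
    return [dict(models[0]) for _ in models]


def load_checkpoint(models):
    # Every rank overwrites its weights with the same checkpoint.
    return [dict(CHECKPOINT) for _ in models]


def variant_1(seed_per_rank):
    # init model -> load checkpoint -> wrap DDP
    models = [init_model(r, seed_per_rank) for r in range(WORLD_SIZE)]
    models = load_checkpoint(models)
    return broadcast_from_rank0(models)


def variant_2(seed_per_rank):
    # init model -> wrap DDP -> load checkpoint
    models = [init_model(r, seed_per_rank) for r in range(WORLD_SIZE)]
    models = broadcast_from_rank0(models)
    return load_checkpoint(models)


for seed_per_rank in (False, True):
    print(seed_per_rank, variant_1(seed_per_rank) == variant_2(seed_per_rank))
```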

Hope this helps
