DDP + Model Parallel: save checkpoint


I am working with a network made of two models:

  • Model1: Data Parallel model parallelized with DDP
  • Model2: Model Parallel model (huge weight matrix) parallelized manually with a sub-part on each DDP process/GPU

Model1 can be easily saved from any process as it is identical on each GPU.
But Model2 is distributed/split across GPUs and must be synchronized somehow.

Question: what would be an elegant way to save Model2?


  • gathering Model2 must happen on CPU because of its size
  • distributed.gather() is not available with the NCCL backend anyway
  • I could save each part to disk from each process, wait in the rank-0 process (distributed.barrier()), reload everything on CPU, merge, and save as a single state dict, but…
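For what it's worth, the disk-based approach in the last bullet could be sketched roughly as below. This is a minimal, torch-free illustration of the pattern only (save per-rank shard → barrier → rank 0 reloads and merges); the function names are hypothetical, and in a real DDP job you would use torch.save/torch.load instead of pickle, torch.cat to merge the shard tensors, and torch.distributed.barrier() between the two phases:

```python
import os
import pickle
import tempfile

def save_shard(ckpt_dir, rank, shard_rows):
    """Each rank dumps its slice of the big weight matrix to disk.

    In a real job: torch.save(model2_part.state_dict(), path), then
    dist.barrier() so rank 0 only merges once all shards are written.
    """
    path = os.path.join(ckpt_dir, f"model2_rank{rank}.pkl")
    with open(path, "wb") as f:
        pickle.dump(shard_rows, f)

def merge_shards(ckpt_dir, world_size):
    """Rank 0 reloads every shard on CPU and concatenates them row-wise.

    In a real job: torch.load(path, map_location="cpu") per shard,
    then torch.cat(shards, dim=0) and one final torch.save().
    """
    merged = []
    for rank in range(world_size):
        path = os.path.join(ckpt_dir, f"model2_rank{rank}.pkl")
        with open(path, "rb") as f:
            merged.extend(pickle.load(f))  # row-wise concatenation
    return merged

if __name__ == "__main__":
    # Simulate two ranks writing their shards, then a rank-0 merge.
    with tempfile.TemporaryDirectory() as d:
        save_shard(d, 0, [[1.0, 2.0]])
        save_shard(d, 1, [[3.0, 4.0]])
        print(merge_shards(d, 2))  # [[1.0, 2.0], [3.0, 4.0]]
```

It works, but it burns disk I/O and leaves temporary shard files around, which is why I'm hoping for something more elegant.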

Thanks !

I’m not sure I understand the question. Can you maybe take a look here? You don’t seem to need distributed.gather(), only distributed.barrier()?