I am working with a network made of two models:
- Model1: Data Parallel model parallelized with DDP
- Model2: Model Parallel model (a huge weight matrix), split manually with one sub-part on each DDP process/GPU
Model1 can easily be saved from any single process, since it is identical on every GPU.
Model2, however, is split across GPUs and must somehow be synchronized before saving.
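To make the setup concrete, here is a minimal sketch of the kind of manual row-wise sharding described above; the function names, the sharding dimension, and the use of `torch.chunk` are my assumptions, not the actual implementation:

```python
import torch


def shard_weight(full_weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Hypothetical helper: each rank keeps one contiguous row-block
    # of the huge weight matrix (row-wise model parallelism).
    shards = torch.chunk(full_weight, world_size, dim=0)
    return shards[rank]


def merge_shards(shards: list) -> torch.Tensor:
    # Inverse operation: concatenating the per-rank shards in rank
    # order reconstructs the full matrix.
    return torch.cat(shards, dim=0)
```

With this layout, `merge_shards([shard_weight(w, r, 4) for r in range(4)])` gives back `w`, which is exactly the merge step rank 0 would have to perform at save time.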
Question: what would be an elegant way to save Model2?
- Model2 must be gathered on CPU because of its size
- distributed.gather() is not available with the NCCL backend anyway
- I could save each part to disk from its own process, wait in the rank-0 process (distributed.barrier()), reload everything on CPU, merge into a single state dict, and save that, but…
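The save-to-disk workaround from the last bullet could be sketched roughly as follows; the file layout, the row-wise concatenation, and the restriction to a single `weight` tensor are illustrative assumptions (bias and other parameters are omitted for brevity):

```python
import os

import torch
import torch.distributed as dist


def save_model2(shard: torch.nn.Module, rank: int, world_size: int,
                tmp_dir: str, out_path: str) -> None:
    # 1. Every rank dumps its own shard to disk, moved to CPU first
    #    so the merge later never touches GPU memory.
    shard_state = {k: v.cpu() for k, v in shard.state_dict().items()}
    torch.save(shard_state, os.path.join(tmp_dir, f"model2_shard_{rank}.pt"))

    # 2. Wait until all ranks have finished writing.
    dist.barrier()

    # 3. Rank 0 reloads every shard on CPU, merges them (assuming
    #    row-wise sharding of a single "weight" matrix), and writes
    #    one consolidated checkpoint.
    if rank == 0:
        weights = []
        for r in range(world_size):
            part = torch.load(os.path.join(tmp_dir, f"model2_shard_{r}.pt"),
                              map_location="cpu")
            weights.append(part["weight"])
        merged = {"weight": torch.cat(weights, dim=0)}
        torch.save(merged, out_path)
```

This keeps peak memory bounded (only rank 0 ever holds the full matrix, and only on CPU), at the cost of the extra round-trip through the filesystem that the bullet above finds inelegant.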