Assume that a PyTorch module has to be modified after some DDP training on a single machine with multiple GPUs — say, a linear classification head has to be replaced/re-initialized (if some condition holds) and then trained some more.
What would be the recommended way to do this in PyTorch?
When the head in the example is replaced, each process does so independently, so every rank ends up with a differently initialized head.
I currently call DistributedDataParallel() on the entire model again after the modification and continue training. I'm not sure whether this is the correct approach, or even what exactly happens in this scenario.
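For concreteness, here is a minimal sketch of what I'm doing, on CPU with the gloo backend so it stays self-contained (the `ToyModel`/`head` names and the class count are just illustrative stand-ins for my real model):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    """Stand-in for the real model: a backbone plus a linear head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head = nn.Linear(8, 4)  # the part that gets replaced later

    def forward(self, x):
        return self.head(self.backbone(x))


def replace_head_and_rewrap(model, num_classes):
    """Swap the head for a fresh one, then wrap the bare module in DDP again.

    Each rank runs this independently, so the new heads start out with
    different random weights on each process; the question is what the
    second DDP() call does about that.
    """
    # Unwrap if the model is already DDP-wrapped.
    module = model.module if isinstance(model, DDP) else model
    in_features = module.head.in_features
    module.head = nn.Linear(in_features, num_classes)  # fresh random init
    # Re-wrap; on GPU this would be DDP(module, device_ids=[local_rank]).
    return DDP(module)
```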
I think that upon the second call to DistributedDataParallel() the model on rank 0 is broadcast to all of the other processes, resulting in consistent replicas and correct training.
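To check this assumption I wrote a small debugging helper (my own, not part of the DDP API) that broadcasts rank 0's copy of each parameter and compares it with the local copy; after re-wrapping, it passes on every rank:

```python
import torch
import torch.distributed as dist


def assert_params_in_sync(module, group=None):
    """Raise if any rank's parameters differ from rank 0's.

    Purely a debugging aid: broadcasts rank 0's value of each parameter
    and compares it elementwise with the local tensor on every rank.
    """
    for name, param in module.named_parameters():
        reference = param.detach().clone()
        dist.broadcast(reference, src=0, group=group)
        if not torch.equal(reference, param.detach()):
            raise RuntimeError(f"parameter {name!r} differs from rank 0")
```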
Any advice on the correct approach?