Modifying DDP model after calling DistributedDataParallel()

Assume that a PyTorch module has to be modified after some DDP training on a single machine with multiple GPUs: say, a linear classification head has to be replaced/re-initialized (if some condition holds) and then trained some more.

What would be the recommended way to do this in PyTorch?
When the new head in this example is replaced, that happens on each process separately, which means the head ends up initialized differently in every process.
I currently call DistributedDataParallel() on the entire model again after the modification and continue training. I’m not sure whether this is the correct approach, or what exactly happens in my scenario.
I think that upon the second call to DistributedDataParallel() the model on rank 0 is copied to all of the other processes, resulting in proper training.
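
To make the scenario concrete, here is a rough sketch of what I mean (MyModel, train(), some_condition, the head attribute name fc, and the sizes are just placeholders; the process group is assumed to be initialized already):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

model = MyModel().cuda()
ddp_model = DistributedDataParallel(model)
train(ddp_model)

# Each rank executes this independently; nn.Linear draws fresh random weights
# per process, so after this the ranks no longer hold the same head.
if some_condition:
    ddp_model.module.fc = nn.Linear(512, 10).cuda()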

Any advice on the correct approach?

@mrshenli any suggestion? Thanks!

The safest way of doing that would be to clone/re-create the original model, copy all the parameters from the DDP model into the clone, and then wrap the clone with DDP. (It may work if you skip some of these steps, but unless perf/memory is an issue, I’d stay on the safe side.) E.g.:

# MyModel and train stand in for your own model class and training loop
original_model = MyModel()
initial_ddp_model = DistributedDataParallel(original_model)
train(initial_ddp_model)

# Re-create the model, copy the trained parameters out of the DDP wrapper,
# and wrap the fresh copy with DDP again
new_model = MyModel()
new_model.load_state_dict(initial_ddp_model.module.state_dict())
new_ddp_model = DistributedDataParallel(new_model)
new_optimizer =…
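
One plausible way to finish that truncated line (the optimizer class and learning rate here are arbitrary placeholders, not something DDP prescribes):

import torch

# The old optimizer still references the parameters of initial_ddp_model, so
# build a fresh optimizer over the new DDP model's parameters.
new_optimizer = torch.optim.SGD(new_ddp_model.parameters(), lr=0.01)

# If the classification head needs to be replaced/re-initialized, do it on
# new_model *before* the DistributedDataParallel(new_model) call above: the
# DDP constructor broadcasts rank 0's parameters and buffers to all other
# ranks, so every process starts from the same re-initialized head.
train(new_ddp_model)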