PyTorch DDP fails to resume after validation using a clone of the model

dataloader = ....  # create a data loader with DDP settings
opts = ...parse()  # user options
master = opts.local_rank == 0

model = create_model(opts)
model_ema = model.clone().eval()  # tracks an exponential moving average (EMA) of the model's weights

for data in dataloader:
    # typical training code: forward, backward and the like
    update_ema_weights(model_ema, model.state_dict())  # update the EMA model's weights
    if opts.validate:
        if master:
            for val_data in valid_dataloader:
                output = model_ema(val_data)  # ... typical validation code

Given the pseudocode above, my DDP processes hang on all GPUs after validation.
However, if I use model instead of model_ema for validation, there is no hang. Does anyone know how to fix this?
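The pseudocode calls update_ema_weights without showing it. For readers following along, here is a minimal sketch of what such a helper typically does, assuming a standard exponential moving average with a hypothetical default decay of 0.999; the body is an illustration, not the poster's actual implementation. It relies on the fact that the tensors returned by state_dict() share storage with the module, so updating them in place updates model_ema.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def update_ema_weights(model_ema: nn.Module, state_dict: dict, decay: float = 0.999) -> None:
    """Blend the live weights in `state_dict` into `model_ema`, in place.

    Hypothetical helper: the thread does not show the real implementation.
    Keys must match, i.e. both models must be wrapped (or unwrapped) the same way.
    """
    for name, ema_tensor in model_ema.state_dict().items():
        current = state_dict[name].detach().to(ema_tensor.device)
        if ema_tensor.dtype.is_floating_point:
            # ema = decay * ema + (1 - decay) * current
            ema_tensor.mul_(decay).add_(current, alpha=1.0 - decay)
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) are copied as-is.
            ema_tensor.copy_(current)
```

Called as update_ema_weights(model_ema, model.state_dict()), this matches the signature used in the loop above.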

Am I correct in assuming that model (and model_ema as well) is a DistributedDataParallel instance? If so, the forward method of DistributedDataParallel sets some internal flags, which could cause a hang.

If you only need the copy for evaluation, use model.module to retrieve the original non-DDP module. Then clone it, call eval(), and run forward on that non-DDP copy. This should leave the DDP instance's internal state intact.
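A minimal sketch of that suggestion, assuming the training model is wrapped in DistributedDataParallel; unwrap, make_ema_copy and validate are hypothetical helper names, not PyTorch APIs. Because the EMA copy is deep-copied from model.module, it is a plain nn.Module, so running it on the master rank alone never goes through DistributedDataParallel.forward.

```python
import copy

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def unwrap(model: nn.Module) -> nn.Module:
    """Return the original module inside a DDP wrapper, or the model itself."""
    return model.module if isinstance(model, DistributedDataParallel) else model


def make_ema_copy(model: nn.Module) -> nn.Module:
    """Deep-copy the unwrapped module so the EMA copy is a plain nn.Module."""
    ema = copy.deepcopy(unwrap(model)).eval()
    for p in ema.parameters():
        p.requires_grad_(False)  # the EMA copy is never trained directly
    return ema


@torch.no_grad()
def validate(model_ema: nn.Module, valid_loader) -> list:
    """Run forward on the plain (non-DDP) copy, without building autograd graphs."""
    return [model_ema(batch) for batch in valid_loader]
```

In the training loop one would then create model_ema = make_ema_copy(model) and pass unwrap(model).state_dict() when updating the EMA weights. DDP's own state_dict() prefixes every key with "module.", which would not match the keys of the unwrapped copy.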

I see, I will give it a shot. But that doesn't explain why no hang occurs when I use "model" to evaluate.

Good point. How did you implement the model.clone() method? IIUC, neither nn.Module nor DistributedDataParallel has a clone() method. Is it a copy.deepcopy?

I use deepcopy to clone the model.