Get `state_dict` from `DistributedDataParallel` model while another thread is running `backward`

Hi,

I am stuck with the following problem:

I have two ranks and each rank has two threads:

  1. (Learning Thread) Continuously runs `backward` on a `DistributedDataParallel` model.
  2. (Main Thread) Periodically retrieves the model's `state_dict` and broadcasts it to other, independent processes.
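
For context, here is a stripped-down sketch of the setup. The model, optimizer, backend (`gloo`), and the broadcast step are just placeholders; in reality the learner trains on real data and the snapshot is sent to separate worker processes (run e.g. with `torchrun --nproc_per_node=2`):

```python
import threading
import time

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def learner_loop(ddp_model, optimizer, stop_event):
    # Learning thread: runs forward/backward continuously on the DDP model.
    while not stop_event.is_set():
        inputs = torch.randn(8, 16)          # dummy batch for illustration
        loss = ddp_model(inputs).sum()
        optimizer.zero_grad()
        loss.backward()                      # triggers DDP's gradient all-reduce
        optimizer.step()


def main():
    dist.init_process_group("gloo")          # assumes env:// rendezvous via torchrun
    model = nn.Linear(16, 4)                 # placeholder model
    ddp_model = DDP(model)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    stop_event = threading.Event()
    learner = threading.Thread(target=learner_loop,
                               args=(ddp_model, optimizer, stop_event))
    learner.start()

    # Main thread: periodically snapshot the weights to broadcast elsewhere.
    for _ in range(10):
        time.sleep(1.0)
        snapshot = {k: v.detach().cpu().clone()
                    for k, v in ddp_model.module.state_dict().items()}
        # ... send `snapshot` to the independent processes (placeholder) ...

    stop_event.set()
    learner.join()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```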

The problem I encounter is that the processes stall. I have tried many workarounds: setting barriers, guarding the model with thread locks, and even serializing the `state_dict` with `torch.save` to an in-memory byte stream and loading it back from there (one of those attempts is sketched below). It looks to me as if this setup is simply not possible.
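
One of those attempts looked roughly like this; the lock name is arbitrary, and the learning thread wraps its `backward`/`step` section in the same lock:

```python
import io
import threading

import torch

state_lock = threading.Lock()  # shared with the learning thread


def snapshot_state_dict(ddp_model):
    # Serialize the underlying module's state_dict to an in-memory buffer
    # while holding the lock, then deserialize it into a detached copy.
    with state_lock:
        buffer = io.BytesIO()
        torch.save(ddp_model.module.state_dict(), buffer)
    buffer.seek(0)
    return torch.load(buffer)
```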

Did anyone ever encounter this? Is there a solution to it?

Btw: if the main thread never retrieves the `state_dict`, everything runs perfectly, so I suspect this is indeed a problem induced by DDP.