DDP training with a network of dynamic depth

Hi, I’m training a self-pruning network with DDP (DistributedDataParallel), i.e. an input can take different paths through the network on its way to the output, so not every layer is used in every forward pass. With DDP, this results in the following error in my case:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).

After I set find_unused_parameters=True, the error goes away, but some of the parameters in the prunable layers seem to fall out of sync: after a single training step, their values are no longer identical across GPUs.
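
One way to confirm the mismatch is to compare each parameter across ranks right after optimizer.step(). The following is only a minimal sketch: the helper name find_desynced_params and the atol argument are made up for illustration, and it assumes the process group is already initialized and that every rank calls it at the same point.

```python
import torch
import torch.distributed as dist


def find_desynced_params(model, atol=0.0):
    """Return the names of parameters whose values differ across ranks.

    Hypothetical helper for debugging. Every rank must call it, because
    all_reduce is a collective operation.
    """
    desynced = []
    for name, param in model.named_parameters():
        if param.numel() == 0:
            continue
        # Element-wise max and min of this parameter over all ranks.
        hi = param.detach().clone()
        lo = param.detach().clone()
        dist.all_reduce(hi, op=dist.ReduceOp.MAX)
        dist.all_reduce(lo, op=dist.ReduceOp.MIN)
        # If all ranks hold the same values, max and min coincide.
        if (hi - lo).abs().max().item() > atol:
            desynced.append(name)
    return desynced
```

Calling print(find_desynced_params(ddp_model)) on rank 0 after a step would list the layers that diverged; an empty list means all ranks agree within atol.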

The behavior I want is: as long as at least one process used a parameter in its forward pass, the update for that parameter is applied on ALL processes, so that every GPU keeps an identical copy. Is there any way to get this behavior with DDP? Thanks!
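
To make the intended semantics concrete, here is a minimal sketch of the behavior written as a manual gradient all-reduce on a plain (non-DDP) model. It is not how DDP works internally, and the function name allreduce_grads_including_unused is made up for illustration. It assumes the process group is initialized, all ranks start from identical parameters, and every rank calls it after loss.backward() and before optimizer.step().

```python
import torch
import torch.distributed as dist


def allreduce_grads_including_unused(model):
    """Average gradients over all ranks, treating a missing grad as zero.

    A parameter gets a gradient on every rank if at least one rank used
    it; a parameter unused on all ranks keeps grad=None everywhere.
    """
    world_size = dist.get_world_size()
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # 1.0 if this rank produced a gradient for the parameter, else 0.0.
        used = torch.tensor(
            [0.0 if param.grad is None else 1.0], device=param.device
        )
        if param.grad is None:
            grad = torch.zeros_like(param)
        else:
            grad = param.grad.detach().clone()
        # Default reduce op is SUM; all ranks iterate in the same order.
        dist.all_reduce(used)
        dist.all_reduce(grad)
        if used.item() > 0:
            # Same averaged gradient on every rank, so updates stay in sync.
            param.grad = grad / world_size
```

This issues one all-reduce pair per parameter, so it is much slower than DDP's bucketed reduction; it is meant only to pin down what "update on all GPUs if any GPU used the parameter" should mean.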