What is the best practice for running distributed adversarial training?

Hi all,

I’m new to PyTorch and struggling with distributed training. Currently, I’m trying to implement a GAN-like training strategy. The training consists of two stages:

  1. Fix the task network, train the discriminator. My workflow is as follows:
    src_data -> T() -> detach() -> D() -> loss(src_pred, src_label)
    tgt_data -> T() -> detach() -> D() -> loss(tgt_pred, tgt_label)

  2. Fix the discriminator, train the task network. My workflow is as follows:
    src_data -> T() -> supervised_loss
    tgt_data -> T() -> D() -> -1 * loss(tgt_pred, tgt_label)

The task network T() and the discriminator network D() are both wrapped in DDP and placed in different process groups. The task network is trained with a supervised loss on labeled data and fine-tuned with the adversarial loss on unlabeled data.
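Roughly, my setup looks like the sketch below. It is heavily simplified: TaskNet and Discriminator are tiny placeholder modules standing in for my real T() and D(), and I assume a torchrun launch with one GPU per process.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class TaskNet(torch.nn.Module):        # placeholder for my real T()
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(128, 64)
    def forward(self, x):
        return self.net(x)

class Discriminator(torch.nn.Module):  # placeholder for my real D()
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(64, 1)
    def forward(self, x):
        return self.net(x)

dist.init_process_group("nccl")        # launched with torchrun
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# One process group per model so their allreduce calls cannot interleave.
ranks = list(range(dist.get_world_size()))
t_group = dist.new_group(ranks=ranks)
d_group = dist.new_group(ranks=ranks)

T = DDP(TaskNet().cuda(), device_ids=[local_rank], process_group=t_group)
D = DDP(Discriminator().cuda(), device_ids=[local_rank], process_group=d_group)
```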

For this setting I have 2 questions:

  1. Is this the correct way to combine two DDP models, or do I have to wrap them into one single module first and then place that under DDP?
  2. While training the task network, I have to fix the discriminator’s parameters. Right now I just set requires_grad of all the discriminator’s parameters to False and turn them back to True after loss.backward() is called (see the snippet after these questions). Is there anything else that needs to change? I found that DDP doesn’t allow unused parameters now, but it seems okay to use a module that requires no gradients at all. Am I doing this correctly?
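For reference, this is roughly what I do in question 2, continuing the sketch above. The data tensors, adv_criterion, and optimizer_T are dummies just to make the snippet self-contained, not my real pipeline.

```python
import torch
import torch.nn.functional as F

# Dummy data / criteria so the sketch runs on its own.
optimizer_T = torch.optim.Adam(T.parameters(), lr=1e-4)
adv_criterion = torch.nn.BCEWithLogitsLoss()
src_data = torch.randn(8, 128).cuda()
src_label = torch.randn(8, 64).cuda()   # dummy supervised target
tgt_data = torch.randn(8, 128).cuda()
tgt_label = torch.ones(8, 1).cuda()     # dummy domain label

# Freeze D so no gradients are computed for its parameters during T's update.
for p in D.parameters():
    p.requires_grad_(False)

sup_loss = F.mse_loss(T(src_data), src_label)
adv_loss = -1 * adv_criterion(D(T(tgt_data)), tgt_label)

optimizer_T.zero_grad()
(sup_loss + adv_loss).backward()
optimizer_T.step()

# Unfreeze D again before the next discriminator stage.
for p in D.parameters():
    p.requires_grad_(True)
```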

I’d appreciate it if somebody could tell me the best practice for implementing multiple models for adversarial training. Thanks in advance!

I have opened a discussion about a similar question regarding two DDP modules in a GAN setting: Calling DistributedDataParallel on multiple Modules?. I’m still trying to determine whether one process group can suffice, but it seems like the safest course of action is to use separate groups for G and D.

Regarding setting requires_grad to False on D while backpropagating G’s loss, I have been meaning to implement the same thing but never got around to it. It seems like the logical approach, as it is just wasted compute to calculate gradients for D when they are going to be discarded.

I think doing exactly that would make it work with a single process group, because you no longer race the allreduce calls from the two models. Also, I think you could put the discriminator in eval mode while doing this, which sidesteps some of the synchronization code paths in DistributedDataParallel. A rough sketch of what I mean is below.
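This is an untested sketch, reusing the placeholder names from your snippets above, but with both models wrapped on the default process group and D frozen plus switched to eval mode for the duration of T’s update:

```python
# Both models on the default process group (no explicit process_group arg).
T = DDP(TaskNet().cuda(), device_ids=[local_rank])
D = DDP(Discriminator().cuda(), device_ids=[local_rank])

def train_task_step(src_data, src_label, tgt_data, tgt_label):
    # Freeze D and switch it to eval mode while T is being updated.
    D.eval()
    for p in D.parameters():
        p.requires_grad_(False)

    sup_loss = F.mse_loss(T(src_data), src_label)
    adv_loss = -1 * adv_criterion(D(T(tgt_data)), tgt_label)

    optimizer_T.zero_grad()
    (sup_loss + adv_loss).backward()   # gradients flow only into T's parameters
    optimizer_T.step()

    # Restore D before its own training stage.
    for p in D.parameters():
        p.requires_grad_(True)
    D.train()
```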