Retain_graph for GAN discriminator

When implementing GAN training code, we call backward twice for the discriminator: once for the real-image case and once for the fake-image case. However, we don't use retain_graph=True for these discriminator backward calls. As far as I know, retain_graph=True is needed to prevent the computation graph from being freed, yet my code (and also the official PyTorch DCGAN tutorial code) works without the retain_graph=True option. Do you know how this works?

Using retain_graph=True keeps the computation graph alive and would allow you to call backward, and thus compute the gradients, multiple times.
The discriminator is trained with different inputs: in the first step netD receives the real_cpu inputs, and the corresponding gradients are computed afterwards via errD_real.backward(). The computation graph is then freed.
In the next step the generator creates a fake image, which is passed to the discriminator as netD(fake.detach()). A new computation graph is created, and the corresponding gradients are computed afterwards and accumulated into the .grad attributes.
Finally, netD's parameters are updated via optimizerD.step() here.

There is thus no need to keep any computation graphs alive as they are not reused.
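The two backward calls above can be sketched as follows. This is a minimal toy sketch mirroring the tutorial's variable names: the tiny netD/netG, the input shapes, and the learning rate are stand-in assumptions, not the tutorial's actual DCGAN models.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the tutorial's models (assumptions, not the real DCGAN nets)
netD = nn.Sequential(nn.Flatten(), nn.Linear(16, 1), nn.Sigmoid())
netG = nn.Linear(8, 16)
criterion = nn.BCELoss()
optimizerD = torch.optim.SGD(netD.parameters(), lr=0.01)

real_cpu = torch.randn(4, 16)        # batch of "real" images
real_label, fake_label = 1.0, 0.0

netD.zero_grad()

# --- real batch: first forward/backward builds and then frees one graph ---
label = torch.full((4,), real_label)
output = netD(real_cpu).view(-1)
errD_real = criterion(output, label)
errD_real.backward()                 # frees THIS graph; no retain_graph needed

# --- fake batch: the second forward pass builds a brand-new graph ---
noise = torch.randn(4, 8)
fake = netG(noise)
label.fill_(fake_label)
output = netD(fake.detach()).view(-1)
errD_fake = criterion(output, label)
errD_fake.backward()                 # gradients accumulate into netD's .grad

optimizerD.step()                    # update netD with grads from both passes
```

Since each backward call consumes its own freshly built graph, neither call needs retain_graph=True; note also that the detach keeps netG's parameters untouched.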

@ptrblck Oh thank you so much! You’ve solved a problem I’ve been struggling with for hours. Now I clearly understand!

@ptrblck I have an additional question. Here we detach to prevent the computation graph needed for training the generator from being freed. However, here we run the forward pass again anyway, so doesn't this create a new computation graph for the generator whether we detach or not?

Is this because the focus is on fake, which is connected to netG and is created by only a single forward pass here?

We are detaching fake in netD(fake.detach()) since we only want to train the discriminator to learn how to detect fake input images. Note that the label is also filled with fake_label.
The errD_fake.backward() call will only compute the gradients for parameters of netD, since the computation graph in netG used to create fake is detached from the previous forward pass.

The second usage via output = netD(fake) keeps fake attached to netG since now we are training the generator. The label tensor is filled with real_label and errG.backward() will now compute the gradients for parameters in the generator netG (as well as netD, but we will zero them out and won’t use them). This step allows the generator to learn how to “fool” the discriminator in “thinking” the fake input is a real image.
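The generator step described above can be sketched the same way; again, the tiny models are toy stand-ins (assumptions), not the tutorial's actual DCGAN networks.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the tutorial's models (assumptions, not the real DCGAN nets)
netD = nn.Sequential(nn.Flatten(), nn.Linear(16, 1), nn.Sigmoid())
netG = nn.Linear(8, 16)
criterion = nn.BCELoss()
optimizerG = torch.optim.SGD(netG.parameters(), lr=0.01)

real_label = 1.0
noise = torch.randn(4, 8)
fake = netG(noise)

netG.zero_grad()
label = torch.full((4,), real_label)  # generator wants D to say "real"
output = netD(fake).view(-1)          # NOT detached: graph reaches back into netG
errG = criterion(output, label)
errG.backward()                       # fills .grad for netG (and netD, unused here)
optimizerG.step()                     # only netG's parameters are updated
```

Because fake stays attached, the backward pass flows through netD into netG; netD also receives gradients, but they are discarded when netD.zero_grad() is called at the start of its own training step.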

@ptrblck I see. We detach fake since the discriminator should be trained only to discriminate whether an image is fake or real, not to learn how the generator makes images look realistic.

So we detach fake because we don't need to track the generator's parameters while training the discriminator. But even if we don't detach, nothing really breaks for the discriminator's training (maybe just some performance issues), yet it matters for the generator. Is this because, without the detach, fake would already have been backpropagated through here, so we could no longer use that computation graph to train the generator unless we ran the forward pass again to create a new fake?

You don’t want to reuse the same computation graph since the generator would be trained to help the discriminator detect fake images. Take another look at the losses and how the targets are defined to train the discriminator vs. the generator.
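The conflict between the two targets can be made concrete: evaluating both losses on the same discriminator output shows they pull the generator's weights in exactly opposite directions. This is a toy sketch with stand-in models; the single-sample batch is an assumption chosen so the two gradients are exactly anti-parallel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-ins (assumptions, not the tutorial's actual DCGAN nets)
netD = nn.Sequential(nn.Flatten(), nn.Linear(16, 1), nn.Sigmoid())
netG = nn.Linear(8, 16)
criterion = nn.BCELoss()

fake = netG(torch.randn(1, 8))       # single fake sample
output = netD(fake).view(-1)

# Discriminator's loss on fakes uses target fake_label (0) ...
errD_fake = criterion(output, torch.zeros(1))
# ... while the generator's loss on the SAME output uses target real_label (1).
errG = criterion(output, torch.ones(1))

# Reusing the graph for a second backward requires retain_graph=True,
# and the resulting generator gradients point in opposite directions:
gD = torch.autograd.grad(errD_fake, netG.weight, retain_graph=True)[0]
gG = torch.autograd.grad(errG, netG.weight)[0]
cos = torch.cosine_similarity(gD.flatten(), gG.flatten(), dim=0)
print(cos.item())                    # ≈ -1: the two objectives are opposed
```

So even if the graph were kept alive, reusing the discriminator's fake-label loss would push netG toward producing images that are *easier* to detect as fake, which is why a fresh forward pass with the real_label target is used for the generator step.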

Oh, that's right. Now I get your point. I really appreciate your help!