When implementing GAN training code, we call backward on the discriminator twice: once for the real-image batch and once for the fake-image batch. However, we don’t use retain_graph=True for the discriminator’s backward calls. As far as I know, retain_graph=True is needed to prevent the computational graph from being freed, yet my code (and also the official PyTorch DCGAN tutorial code) works without the retain_graph=True option. Do you know how this works?
retain_graph=True keeps the computation graph alive and would allow you to call backward and thus calculate the gradients multiple times.
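A minimal, self-contained sketch (not from the tutorial) showing both behaviors: with retain_graph=True a second backward succeeds and accumulates gradients, while without it the freed buffers cause a RuntimeError.

```python
import torch

x = torch.ones(3, requires_grad=True)

# With retain_graph=True the intermediate buffers survive the first
# backward, so a second backward on the same graph works.
y = (x * x).sum()
y.backward(retain_graph=True)
y.backward()  # gradients accumulate: 2*x added twice -> 4 per element
print(x.grad)

# Without retain_graph=True the buffers are freed after the first
# backward, and a second call on the same graph raises a RuntimeError.
x.grad = None
z = (x * x).sum()
z.backward()
try:
    z.backward()
    second_backward_ok = True
except RuntimeError:
    second_backward_ok = False
print("second backward succeeded:", second_backward_ok)
```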
The discriminator is trained with different inputs. In the first step, netD receives the real_cpu inputs, and the corresponding gradients are computed afterwards via errD_real.backward(). The computation graph is then freed.
In the next step, the generator creates a fake image, which is passed to the discriminator as netD(fake.detach()). A new computation graph is created, and the corresponding gradients are computed afterwards and accumulated into the gradients from the real batch. netD’s parameters are then updated via optimizerD.step().
There is thus no need to keep any computation graph alive, as none is reused.
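The two-backward pattern described above can be sketched as follows. This is a toy stand-in, not the tutorial’s code: the tiny netD/netG architectures and tensor shapes are illustrative assumptions, but the control flow (two forward passes, two backward calls, no retain_graph) matches the pattern being discussed.

```python
import torch
import torch.nn as nn

# Illustrative stand-ins for the tutorial's models (assumed shapes).
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
netG = nn.Linear(2, 4)
criterion = nn.BCELoss()
optimizerD = torch.optim.SGD(netD.parameters(), lr=0.1)

real = torch.randn(8, 4)   # stands in for real_cpu
noise = torch.randn(8, 2)

netD.zero_grad()

# 1) Real batch: forward + backward. The graph built here is freed
#    after this backward call -- and it is never needed again.
output = netD(real).view(-1)
errD_real = criterion(output, torch.ones(8))
errD_real.backward()

# 2) Fake batch: a *new* forward pass builds a *new* graph, so no
#    retain_graph=True is needed. detach() cuts the graph off from netG.
fake = netG(noise)
output = netD(fake.detach()).view(-1)
errD_fake = criterion(output, torch.zeros(8))
errD_fake.backward()  # gradients accumulate into netD's .grad fields

optimizerD.step()
```

Because of the detach, only netD receives gradients in this step; netG’s parameters are untouched.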
@ptrblck Oh thank you so much! You’ve solved a problem I’ve been struggling with for hours. Now I clearly understand!
@ptrblck I have an additional question. Here we detach to avoid freeing the computational graph needed for training the generator. But since we feed forward again, doesn’t this create a new computational graph for the generator whether we detach or not?
Or is it because we focus on fake, which is related to netG and is fed forward only once here?
We are detaching via netD(fake.detach()) since we only want to train the discriminator to learn how to detect fake input images. Note that the label is also filled with fake_label here. The errD_fake.backward() call will only compute the gradients for the parameters of netD, since the computation graph in netG used to create fake is detached from the previous forward pass.
The second usage via output = netD(fake) keeps fake attached to netG, since now we are training the generator. The label tensor is filled with real_label, and errG.backward() will now compute the gradients for the parameters of the generator netG (as well as of netD, but we will zero them out and won’t use them). This step allows the generator to learn how to “fool” the discriminator into “thinking” the fake input is a real image.
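The generator step can be sketched the same way (again a toy stand-in with assumed shapes, not the tutorial’s code): no detach, a “real” target, and gradients flowing back through netD into netG.

```python
import torch
import torch.nn as nn

# Illustrative stand-ins (assumed shapes, not the tutorial's models).
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
netG = nn.Linear(2, 4)
criterion = nn.BCELoss()
optimizerG = torch.optim.SGD(netG.parameters(), lr=0.1)

netG.zero_grad()

fake = netG(torch.randn(8, 2))

# No detach here: fake stays attached to netG's graph.
output = netD(fake).view(-1)

# The generator's target is "real": it is rewarded when the
# discriminator is fooled.
errG = criterion(output, torch.ones(8))
errG.backward()  # gradients flow through netD back into netG

optimizerG.step()
```

Note that netD’s parameters also receive gradients here; they are simply zeroed out before the next discriminator update.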
@ptrblck I see. We detach fake since the discriminator should be trained only to discriminate whether an image is fake or real, not on how the generator makes images look realistic.
So we detach fake because we don’t need to track the generator’s parameters while training the discriminator. But if we don’t detach, nothing really breaks for the discriminator’s training (maybe just some performance issues), yet it matters for the generator. Is this because the graph of fake would already have been backpropagated through if we don’t detach, so we could no longer use that computational graph to train the generator unless we feed forward again to create a new one?
You don’t want to reuse the same computation graph since the generator would be trained to help the discriminator detect fake images. Take another look at the losses and how the targets are defined to train the discriminator vs. the generator.
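Both failure modes can be shown in a few lines (a toy stand-in with assumed shapes): without detach, the discriminator’s “this is fake” loss also deposits gradients in netG, and those gradients would push netG toward making images *easier* to detect; and after that backward, the graph through netG is freed, so reusing it fails anyway.

```python
import torch
import torch.nn as nn

# Illustrative stand-ins (assumed shapes).
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
netG = nn.Linear(2, 4)
criterion = nn.BCELoss()

fake = netG(torch.randn(8, 2))

# Without detach, the discriminator's fake-detection loss also
# produces gradients in netG -- the opposite of what we want netG
# to learn.
errD_fake = criterion(netD(fake).view(-1), torch.zeros(8))
errD_fake.backward()
print("netG got gradients:", netG.weight.grad is not None)

# The graph through netG is now freed, so reusing fake for the
# generator step fails; a fresh forward pass would be required.
errG = criterion(netD(fake).view(-1), torch.ones(8))
try:
    errG.backward()
    graph_freed = False
except RuntimeError:
    graph_freed = True
print("graph already freed:", graph_freed)
```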
Oh that’s right. Now I got your point. I really appreciate your help!