When implementing GAN training code, we call backward on the discriminator twice: once for the real-image batch and once for the fake-image batch. However, we don’t use retain_graph=True for the discriminator’s backward passes. As far as I know, retain_graph=True is needed to prevent the computational graph from being freed, yet my code (and also the official PyTorch DCGAN tutorial code) works without the retain_graph=True option. Do you know why this works?
Using retain_graph=True will keep the computation graph alive and would allow you to call backward, and thus calculate the gradients, multiple times.
The discriminator is trained with different inputs: in the first step netD will get the real_cpu inputs and the corresponding gradients will be computed afterwards using errD_real.backward(). The computation graph is then freed.
In the next step the generator will create a fake image which is then passed to the discriminator as netD(fake.detach()). A new computation graph is created and the corresponding gradients will be computed afterwards and accumulated into the .grad attributes.
Finally, netD’s parameters will be updated via optimizerD.step() here.
There is thus no need to keep any computation graphs alive as they are not reused.
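To make this concrete, here is a minimal, self-contained sketch of the discriminator step (toy linear models and random tensors instead of the tutorial’s DCGAN; names like netD, real_cpu, and errD_real are only kept for readability). It shows the two separate graphs being built and each being backpropagated through exactly once, with no retain_graph=True needed:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the tutorial's models (not the real DCGAN architectures)
netD = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
netG = nn.Sequential(nn.Linear(4, 8), nn.Tanh())
criterion = nn.BCEWithLogitsLoss()
optimizerD = torch.optim.Adam(netD.parameters(), lr=2e-4)

real_cpu = torch.randn(5, 8)              # stand-in for a batch of real images
noise = torch.randn(5, 4)

netD.zero_grad()

# Graph 1: real batch -> netD -> loss; freed right after backward()
label = torch.ones(5, 1)                  # real_label
errD_real = criterion(netD(real_cpu), label)
errD_real.backward()                      # populates netD's .grad

# Graph 2: fake batch -> netD -> loss; a brand-new graph, netG is cut off
fake = netG(noise)
label.fill_(0.)                           # fake_label
errD_fake = criterion(netD(fake.detach()), label)
errD_fake.backward()                      # gradients are accumulated into .grad

optimizerD.step()                         # update netD with the summed gradients
```

Each backward() call frees its own graph right after it is used, and since neither graph is needed again, nothing has to be retained.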
@ptrblck Oh thank you so much! You’ve solved a problem I’ve been struggling with for hours. Now I clearly understand!
@ptrblck I have an additional question. Here we detach in order to avoid freeing the computational graph that is needed for training the generator. However, here we feed forward again, so doesn’t this create a new computational graph for the generator whether we detach or not?
Is this because we focus on fake, which is related to netG and is only fed forward once here?
We are detaching fake in netD(fake.detach()) since we only want to train the discriminator to learn how to detect fake input images. Note that the label is also filled with fake_label.
The errD_fake.backward() call will only compute the gradients for parameters of netD, since the computation graph in netG used to create fake is detached from the previous forward pass.
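As a quick sanity check (the same kind of toy setup as above, not the tutorial code), you can verify that backpropagating the fake-batch loss through netD(fake.detach()) leaves the generator’s .grad attributes untouched:

```python
import torch
import torch.nn as nn

netD = nn.Sequential(nn.Linear(8, 1))     # toy discriminator
netG = nn.Sequential(nn.Linear(4, 8))     # toy generator
criterion = nn.BCEWithLogitsLoss()

fake = netG(torch.randn(5, 4))
fake_label = torch.zeros(5, 1)

# Detaching cuts the path back into netG, so only netD receives gradients
errD_fake = criterion(netD(fake.detach()), fake_label)
errD_fake.backward()

print(netD[0].weight.grad is None)        # False -> discriminator got gradients
print(netG[0].weight.grad is None)        # True  -> generator was never reached
```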
The second usage via output = netD(fake) keeps fake attached to netG, since now we are training the generator. The label tensor is filled with real_label and errG.backward() will now compute the gradients for the parameters in the generator netG (as well as netD, but we will zero them out and won’t use them). This step allows the generator to learn how to “fool” the discriminator into “thinking” the fake input is a real image.
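Continuing the same toy sketch (again just an illustration under the assumptions above, not the tutorial code), the generator step keeps fake attached and flips the target to the real label:

```python
import torch
import torch.nn as nn

netD = nn.Sequential(nn.Linear(8, 1))     # toy discriminator
netG = nn.Sequential(nn.Linear(4, 8))     # toy generator
criterion = nn.BCEWithLogitsLoss()
optimizerG = torch.optim.Adam(netG.parameters(), lr=2e-4)

netG.zero_grad()
fake = netG(torch.randn(5, 4))            # stays attached to netG this time
output = netD(fake)                       # no .detach() here
real_label = torch.ones(5, 1)             # "pretend the fakes are real"
errG = criterion(output, real_label)
errG.backward()                           # fills .grad for netG (and netD)
optimizerG.step()                         # but only netG's parameters are updated
```

The gradients that land in netD here are simply discarded by zeroing them at the start of the next discriminator update.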
@ptrblck I see. We detach fake since the discriminator should be trained just to discriminate whether an image is fake or real, not to learn how the generator makes images look realistic.
So we detach fake because we don’t need to track the generator’s parameters while training the discriminator. But also, if we don’t detach, nothing really breaks for training the discriminator (maybe just some performance issues), whereas it does matter for the generator. Is that because, without detaching, fake would already have been backpropagated through here, so we could no longer use that computational graph to train the generator unless we fed forward again to create a new fake?
You don’t want to reuse the same computation graph since the generator would be trained to help the discriminator detect fake images. Take another look at the losses and how the targets are defined to train the discriminator vs. the generator.
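In other words (toy tensors again, mirroring the tutorial’s real_label = 1 / fake_label = 0 convention), the same fake images get opposite targets in the two steps:

```python
import torch
import torch.nn as nn

netD = nn.Sequential(nn.Linear(8, 1))
netG = nn.Sequential(nn.Linear(4, 8))
criterion = nn.BCEWithLogitsLoss()
fake = netG(torch.randn(5, 4))

# Discriminator step: fakes are labelled 0 -> netD learns to spot them
errD_fake = criterion(netD(fake.detach()), torch.zeros(5, 1))

# Generator step: the same fakes are labelled 1 -> netG learns to fool netD
errG = criterion(netD(fake), torch.ones(5, 1))
```

Backpropagating the discriminator’s loss (with the fake-label targets) into netG would push the generator in exactly the wrong direction.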
Oh that’s right. Now I got your point. I really appreciate your help!