Hi there,
I’m new to PyTorch and I’d like to implement a GAN model. I see there’s a tutorial (DCGAN Tutorial — PyTorch Tutorials 1.12.1+cu102 documentation) that uses two optimizers to update the parameters of the generator and the discriminator separately, so that the generator loss does not update the discriminator parameters, and the discriminator loss does not update the generator parameters.
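For reference, the tutorial’s two-optimizer setup looks roughly like this (the toy networks and hyperparameters here are my own placeholders, not the tutorial’s exact code):

```python
import torch.nn as nn
import torch.optim as optim

# Toy stand-ins for the tutorial's generator and discriminator
netG = nn.Sequential(nn.Linear(100, 64), nn.ReLU(), nn.Linear(64, 784))
netD = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 1))

# Each optimizer only ever steps its own network's parameters, so
# optimizerD.step() never touches netG's weights and vice versa.
optimizerD = optim.Adam(netD.parameters(), lr=2e-4)
optimizerG = optim.Adam(netG.parameters(), lr=2e-4)
```

This works because the two parameter sets are completely disjoint, which is exactly the assumption that breaks in my case.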
But what if I want to build a more complicated model that needs to share some lower layers between the generator and the discriminator? I can’t put the generator’s and the discriminator’s parameters into different optimizers, because they share many parameters. How can I train such a model correctly and efficiently?
Here’s the sample network structure:
```python
class Generator(nn.Module):
    # self.encoder: shared network between generator and discriminator
    # self.top: linear layer

    def forward(self, input):
        return self.top(self.encoder(input))

class Discriminator(nn.Module):
    # self.encoder: shared network between generator and discriminator
    # self.top: linear layer

    def forward(self, input):
        return self.top(self.encoder(input))

netG = Generator()
netD = Discriminator()
```

Training loop:
```python
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
############################
netD.zero_grad()
# Calculate D's loss on the all-real batch
errD_real = criterion(netD(real_img), real_label)
# Calculate D's loss on the all-fake batch; detach() stops gradients
# from flowing back into netG here
errD_fake = criterion(netD(netG(noise).detach()), fake_label)
# Calculate gradients for D in the backward pass
errD = errD_real + errD_fake
errD.backward()

############################
# (2) Update G network: maximize log(D(G(z)))
############################
netG.zero_grad()
# Fake labels are real for the generator cost
errG = criterion(netD(netG(noise)), real_label)
# !!! We don't need the gradient of netD with respect to errG, but we still
# need the gradient of netG with respect to errG (and the gradient of netD
# with respect to errD above). So how can I skip backpropagation into netD's
# own parameters for this loss, given that netG and netD share some lower
# layers?
errG.backward()
```
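To make the problem concrete, here is a toy reproduction (all names are mine, not from the tutorial) showing that the generator loss’s backward pass also populates `.grad` on the discriminator-only head:

```python
import torch
import torch.nn as nn

# Toy shared-encoder GAN: encoder is shared, each side has its own top layer.
encoder = nn.Linear(4, 4)
g_top = nn.Linear(4, 4)   # generator-only head
d_top = nn.Linear(4, 1)   # discriminator-only head

noise = torch.randn(2, 4)
# Generator forward: shared encoder + generator head
fake = g_top(encoder(noise))
# Discriminator forward on the fake batch: shared encoder + discriminator head
score = d_top(encoder(fake))
errG = score.mean()  # stand-in for the actual generator loss
errG.backward()

# The generator-side gradients I want:
print(g_top.weight.grad is not None)    # True
print(encoder.weight.grad is not None)  # True (shared layer, needed for G)
# The unwanted gradient on the discriminator-only head:
print(d_top.weight.grad is not None)    # True -- this is what I'd like to skip
```

Since the encoder’s parameters must receive gradients from errG (they belong to the generator path too), I can’t simply rely on separate optimizers to keep the updates apart.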