How to train a generator and two discriminators simultaneously - How to relate second discriminator?

Hi everyone,

I want to train a GAN that has one generator and two discriminators. The discriminators have different architectures but use the same loss function, and I am a little confused about how to handle the second discriminator.

Here is what I have done briefly:

x,y = # input and ground truth

details_net.train() # generator
disc_one.train()
disc_two.train()

# train generator
details_optim.zero_grad()
gen_imgs = details_net(x)
g_loss = criterion(disc_one(gen_imgs), valid)
g_loss.backward(retain_graph=True)
details_optim.step()


# train discriminator one
disc_one_optim.zero_grad()
real_loss = criterion(disc_one(y), valid)
fake_loss = criterion(disc_one(gen_imgs), fake)
disc_one_loss = (real_loss + fake_loss) / 2
disc_one_loss.backward(retain_graph=True)
disc_one_optim.step()


# train discriminator two
disc_two_optim.zero_grad()
real_loss = criterion(disc_two(y), valid)
fake_loss = criterion(disc_two(gen_imgs), fake)
disc_two_loss = (real_loss + fake_loss) / 2
disc_two_loss.backward()
disc_two_optim.step()

The above code runs inside the DataLoader loop.
I am using nn.MSELoss as the loss function for both discriminators and the generator.

And my question is, am I using .backward and criterion correctly?

Thank you for your help

  1. You could call p.requires_grad_(False) on every p in XXX.parameters() for the parts you are not training, and set them back to True for the parts you are training. Most GAN examples do that; it avoids computing and storing unneeded gradients.
  2. You could compute g_loss1 and g_loss2, one per discriminator, and mix both criteria with g_loss = alpha * g_loss1 + beta * g_loss2.
  3. You likely don’t need retain_graph=True in g_loss.backward() if you detach the fake image (after training the generator) or regenerate it. A while ago, most GANs took several discriminator steps per generator step, and I don’t know whether that has changed, so in practice it would mostly be regenerating.
  4. Instead of retain_graph=True in the disc_one training, you might just add disc_one_loss and disc_two_loss and do a single backward.
  5. You didn’t ask about that, but it is very unlikely that you want disc_..._loss = real_loss + fake_loss.
    If the generator minimizes criterion, then you probably want it to be real_loss - fake_loss, so that the discriminators push up the criterion for fake images.
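
To make points 1 to 4 concrete, here is a rough sketch of one training step. It is only an illustration under assumptions: the three nn modules are tiny hypothetical stand-ins for your real networks, alpha and beta are arbitrary weights, set_requires_grad is a helper I made up, and the discriminator loss is left as written in your question (point 5 is not applied here).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real networks; only the shapes matter here.
details_net = nn.Linear(8, 8)  # generator
disc_one = nn.Linear(8, 1)     # discriminator one
disc_two = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))  # discriminator two

criterion = nn.MSELoss()
details_optim = torch.optim.Adam(details_net.parameters(), lr=1e-3)
disc_one_optim = torch.optim.Adam(disc_one.parameters(), lr=1e-3)
disc_two_optim = torch.optim.Adam(disc_two.parameters(), lr=1e-3)

alpha, beta = 0.5, 0.5  # assumed mixing weights for the two generator losses


def set_requires_grad(module, flag):
    """Hypothetical helper: switch gradient tracking for all parameters of a module."""
    for p in module.parameters():
        p.requires_grad_(flag)


def train_step(x, y):
    valid = torch.ones(x.size(0), 1)
    fake = torch.zeros(x.size(0), 1)

    # Generator step: freeze both discriminators (point 1) and
    # mix their criteria into one generator loss (point 2).
    set_requires_grad(disc_one, False)
    set_requires_grad(disc_two, False)
    details_optim.zero_grad()
    gen_imgs = details_net(x)
    g_loss = (alpha * criterion(disc_one(gen_imgs), valid)
              + beta * criterion(disc_two(gen_imgs), valid))
    g_loss.backward()  # no retain_graph needed
    details_optim.step()

    # Discriminator step: detach the fake images (point 3) so the
    # generator graph is not reused, then do one backward over the sum (point 4).
    set_requires_grad(disc_one, True)
    set_requires_grad(disc_two, True)
    disc_one_optim.zero_grad()
    disc_two_optim.zero_grad()
    fake_imgs = gen_imgs.detach()
    disc_one_loss = (criterion(disc_one(y), valid) + criterion(disc_one(fake_imgs), fake)) / 2
    disc_two_loss = (criterion(disc_two(y), valid) + criterion(disc_two(fake_imgs), fake)) / 2
    (disc_one_loss + disc_two_loss).backward()
    disc_one_optim.step()
    disc_two_optim.step()

    return g_loss.item(), disc_one_loss.item(), disc_two_loss.item()
```

You would call train_step(x, y) once per batch inside the DataLoader loop. Because the two discriminators share no parameters, backpropagating the sum gives each one the same gradients as two separate backward calls would.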

Best regards

Thomas
