Hi everyone,
I want to train a GAN with one generator and two discriminators. The discriminators have different architectures but use the same loss function. I am a little confused about how to handle the second discriminator.
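The two discriminators have different structures, so their outputs do not necessarily have the same shape. This matters when one shared criterion compares each output against a fixed target tensor. The sketch below makes some assumptions: 3x32x32 inputs and two hypothetical discriminators, `GlobalDisc` (one score per image) and `PatchDisc` (a map of patch scores). It builds each target with `torch.ones_like`, so the same `nn.MSELoss` instance works for both.

```python
import torch
import torch.nn as nn

# Hypothetical discriminators: both take a 3x32x32 image, but their structures differ.
class GlobalDisc(nn.Module):
    """Scores the whole image with a single number per sample."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, 4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(16 * 16 * 16, 1),
        )

    def forward(self, img):
        return self.net(img)  # shape (N, 1)


class PatchDisc(nn.Module):
    """Scores local patches, so the output is a spatial map."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, 4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.LeakyReLU(0.2),
            nn.Conv2d(16, 1, 4, stride=2, padding=1),  # 16x16 -> 8x8
        )

    def forward(self, img):
        return self.net(img)  # shape (N, 1, 8, 8)


criterion = nn.MSELoss()  # one loss function shared by both discriminators
disc_one, disc_two = GlobalDisc(), PatchDisc()
imgs = torch.randn(4, 3, 32, 32)

for disc in (disc_one, disc_two):
    pred = disc(imgs)
    # Build the "real" target from this discriminator's own output shape,
    # so one fixed target tensor does not have to fit both outputs.
    loss = criterion(pred, torch.ones_like(pred))
    print(type(disc).__name__, tuple(pred.shape), loss.item())
```

If a single `valid` tensor of shape (N, 1) were reused for `PatchDisc`, `nn.MSELoss` would broadcast it with a warning instead of failing loudly, which is easy to miss.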
Here is a brief version of what I have done:
```python
# criterion = nn.MSELoss(); valid / fake are the target tensors for real and generated samples
x, y = ...  # input and ground truth from the DataLoader
details_net.train()  # generator
disc_one.train()
disc_two.train()

# train generator
details_optim.zero_grad()
gen_imgs = details_net(x)
g_loss = criterion(disc_one(gen_imgs), valid)
g_loss.backward(retain_graph=True)
details_optim.step()

# train discriminator one
disc_one_optim.zero_grad()
real_loss = criterion(disc_one(y), valid)
fake_loss = criterion(disc_one(gen_imgs), fake)
disc_one_loss = (real_loss + fake_loss) / 2
disc_one_loss.backward(retain_graph=True)
disc_one_optim.step()

# train discriminator two
disc_two_optim.zero_grad()
real_loss = criterion(disc_two(y), valid)
fake_loss = criterion(disc_two(gen_imgs), fake)
disc_two_loss = (real_loss + fake_loss) / 2
disc_two_loss.backward()
disc_two_optim.step()
```
The code above runs inside the DataLoader loop.
I am using `nn.MSELoss` as the loss function for both discriminators and the generator.
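With `nn.MSELoss` and targets of 1 for real and 0 for generated samples, these losses amount to the least-squares GAN (LSGAN) objective, up to constant factors. The helpers below sketch that under the assumption that the targets are ones and zeros (the post does not say what `valid` and `fake` contain). `disc_loss` and `gen_loss` are hypothetical names, and the same pair can serve both discriminators because the targets follow each prediction's shape.

```python
import torch
import torch.nn as nn

mse = nn.MSELoss()

def disc_loss(pred_real: torch.Tensor, pred_fake: torch.Tensor) -> torch.Tensor:
    """Least-squares discriminator loss: push real scores to 1 and fake scores to 0."""
    real_loss = mse(pred_real, torch.ones_like(pred_real))
    fake_loss = mse(pred_fake, torch.zeros_like(pred_fake))
    return (real_loss + fake_loss) / 2

def gen_loss(pred_fake: torch.Tensor) -> torch.Tensor:
    """Least-squares generator loss: push the discriminator's fake scores to 1."""
    return mse(pred_fake, torch.ones_like(pred_fake))

# A perfect discriminator has zero loss; a completely fooled one has loss 1.
print(disc_loss(torch.ones(4, 1), torch.zeros(4, 1)).item())  # 0.0
print(disc_loss(torch.zeros(4, 1), torch.ones(4, 1)).item())  # 1.0
print(gen_loss(torch.zeros(2, 1, 8, 8)).item())               # 1.0
```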
My question is: am I using `.backward()` and `criterion` correctly?
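For comparison, one common way to structure this step is to let the generator loss include both discriminators and to train each discriminator on `gen_imgs.detach()`. This is a sketch, not the only valid arrangement. Detaching means the discriminator losses never backpropagate into the generator, so `retain_graph=True` is not needed. It also avoids backpropagating through generator weights that `details_optim.step()` has already updated in place, which recent PyTorch releases can reject with an in-place modification `RuntimeError`. The tiny `nn.Linear` models, the equal 1/2 weighting of the two adversarial terms, and the list standing in for the DataLoader are all assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny hypothetical stand-ins for details_net, disc_one and disc_two.
details_net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
disc_one = nn.Sequential(nn.Linear(8, 1))  # one score per sample
disc_two = nn.Sequential(nn.Linear(8, 4), nn.LeakyReLU(0.2), nn.Linear(4, 2))  # different structure

criterion = nn.MSELoss()
details_optim = torch.optim.Adam(details_net.parameters(), lr=1e-3)
disc_one_optim = torch.optim.Adam(disc_one.parameters(), lr=1e-3)
disc_two_optim = torch.optim.Adam(disc_two.parameters(), lr=1e-3)

def adv_loss(pred, is_real):
    # Targets are built per call so they always match this discriminator's output shape.
    target = torch.ones_like(pred) if is_real else torch.zeros_like(pred)
    return criterion(pred, target)

def train_step(x, y):
    # 1) Generator: try to fool BOTH discriminators (equal weighting is an assumption).
    details_optim.zero_grad()
    gen_imgs = details_net(x)
    g_loss = (adv_loss(disc_one(gen_imgs), True) + adv_loss(disc_two(gen_imgs), True)) / 2
    g_loss.backward()  # no retain_graph needed
    details_optim.step()

    # Detach so the discriminator losses do not backpropagate into the generator.
    fake_imgs = gen_imgs.detach()

    # 2) Discriminator one (zero_grad also clears grads left over from g_loss.backward()).
    disc_one_optim.zero_grad()
    d1_loss = (adv_loss(disc_one(y), True) + adv_loss(disc_one(fake_imgs), False)) / 2
    d1_loss.backward()
    disc_one_optim.step()

    # 3) Discriminator two: same criterion, its own optimizer.
    disc_two_optim.zero_grad()
    d2_loss = (adv_loss(disc_two(y), True) + adv_loss(disc_two(fake_imgs), False)) / 2
    d2_loss.backward()
    disc_two_optim.step()
    return g_loss.item(), d1_loss.item(), d2_loss.item()

# Hypothetical (input, ground truth) batches standing in for the DataLoader.
loader = [(torch.randn(4, 8), torch.randn(4, 8)) for _ in range(3)]
for x, y in loader:
    losses = train_step(x, y)
print(losses)
```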
Thank you for your help