Hi there,
I’m currently trying to train a GAN using PyTorch Lightning, and I want to implement multiple discriminator updates per generator update, but I am unsure of the best practice.
Here is my training step:
def training_step(self, batch):
    """Run one manual-optimization GAN step.

    Trains the discriminator ``n_critic`` times (``self.n_critic`` if set,
    otherwise 1 — matching the original single update), then the generator
    once. Requires ``self.automatic_optimization = False`` in ``__init__``.

    Args:
        batch: tuple ``(real_data, coords)`` as produced by the dataloader.
    """
    optG, optD = self.optimizers()

    real_data, coords = batch
    batch_size = real_data.size(0)
    # Real/fake target labels, moved to real_data's dtype/device by type_as.
    real_labels = torch.full((batch_size,), 1.0).type_as(real_data)
    fake_labels = torch.full((batch_size,), 0.0).type_as(real_data)

    # ---- Discriminator: possibly several updates per generator update ----
    # NOTE(review): set self.n_critic (e.g. 5) to train D more often;
    # defaults to 1 so existing behavior is unchanged.
    n_critic = getattr(self, "n_critic", 1)
    self.toggle_optimizer(optD)
    for _ in range(n_critic):
        # Fresh forward passes each iteration so every backward() owns its
        # own graph — this removes the need for retain_graph=True.
        # TODO(review): the generator is fed an empty tensor here, mirroring
        # the original code; confirm this is the intended noise input.
        noise = torch.tensor([]).type_as(real_data)
        fake_data = self.model.generator(noise).type_as(real_data).reshape(-1, 1)
        disc_value = self.model.discriminator(coords)

        optD.zero_grad()
        errD_real = self.adversarial_loss(
            disc_value, real_labels, real_data
        )  # Discriminator real loss
        # detach() keeps discriminator gradients from flowing into the generator.
        errD_fake = self.adversarial_loss(
            disc_value, fake_labels, fake_data.detach()
        )  # Discriminator fake loss
        errD = (errD_real + errD_fake) / 2
        self.manual_backward(errD)
        optD.step()
    self.log("train_d_loss_step", errD, prog_bar=True)
    self.untoggle_optimizer(optD)

    # ---- Generator: one update ----
    self.toggle_optimizer(optG)
    noise = torch.tensor([]).type_as(real_data)
    fake_data = self.model.generator(noise).type_as(real_data).reshape(-1, 1)
    disc_value = self.model.discriminator(coords)

    optG.zero_grad()
    errG = self.adversarial_loss(disc_value, real_labels, fake_data)
    self.log("train_g_loss_step", errG, prog_bar=True)
    self.manual_backward(errG)
    optG.step()
    self.untoggle_optimizer(optG)
I have seen two schools of thought: either write a for loop around the discriminator training section, or set the ‘frequency’ parameter in configure_optimizers, like the following:
return ( {'optimizer': optD, 'frequency': 5}, {'optimizer': optG, 'frequency': 1} )
Any advice on what implementation to use in this framework would be greatly appreciated!