How to implement multiple discriminator updates per generator update in PyTorch Lightning

Hi there,
I’m currently trying to train a GAN using PyTorch Lightning, and I want to implement multiple discriminator updates per generator update, but I am unsure of the best practice.

Here is my training step:

    import torch

    # Method of a LightningModule that sets self.automatic_optimization = False in __init__
    def training_step(self, batch, batch_idx):
        optG, optD = self.optimizers()

        # Data and real/fake labels
        real_data, coords = batch
        real_labels = torch.ones(real_data.size(0)).type_as(real_data)
        fake_labels = torch.zeros(real_data.size(0)).type_as(real_data)

        # Generate fake data (this generator is called with an empty input tensor)
        gen_input = torch.tensor([])
        fake_data = self.model.generator(gen_input).type_as(real_data).reshape(-1, 1)

        # Discriminator output, reused by both the generator and discriminator losses
        disc_value = self.model.discriminator(coords)

        # Train the generator (retain_graph=True because disc_value's graph is reused for errD)
        self.toggle_optimizer(optG)
        optG.zero_grad()
        errG = self.adversarial_loss(disc_value, real_labels, fake_data)
        self.log("train_g_loss_step", errG, prog_bar=True)
        self.manual_backward(errG, retain_graph=True)
        optG.step()
        self.untoggle_optimizer(optG)

        # Train the discriminator
        self.toggle_optimizer(optD)
        optD.zero_grad()
        errD_real = self.adversarial_loss(disc_value, real_labels, real_data)  # real loss
        errD_fake = self.adversarial_loss(disc_value, fake_labels, fake_data.detach())  # fake loss
        errD = (errD_real + errD_fake) / 2
        self.log("train_d_loss_step", errD, prog_bar=True)
        self.manual_backward(errD)
        optD.step()
        self.untoggle_optimizer(optD)

I have seen two schools of thought: either write a for loop around the discriminator update, or set the ‘frequency’ key in configure_optimizers, with something like the following:

    return ({'optimizer': optD, 'frequency': 5}, {'optimizer': optG, 'frequency': 1})
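For the ‘frequency’ route: as far as I know, `frequency` was only honoured with automatic optimization in Lightning 1.x, and Lightning 2.0 removed multiple-optimizer support from automatic optimization together with that key. Since the `training_step` above already uses manual optimization, the schedule would have to be computed inside `training_step`. The hypothetical helper below reproduces, in plain Python, the cycling that `frequency` described: with frequencies `(5, 1)`, optimizer 0 steps on five consecutive batches, optimizer 1 on the next one, and the cycle repeats.

```python
import bisect
from itertools import accumulate
from typing import Sequence


def active_optimizer_index(batch_idx: int, frequencies: Sequence[int]) -> int:
    """Index of the optimizer that should step on this batch (hypothetical helper).

    Frequencies (5, 1) give the pattern 0, 0, 0, 0, 0, 1, 0, 0, ...
    """
    if not frequencies or any(f <= 0 for f in frequencies):
        raise ValueError("frequencies must be a non-empty sequence of positive ints")
    # Cumulative upper bounds of each optimizer's slot in one cycle, e.g. (5, 1) -> [5, 6]
    bounds = list(accumulate(frequencies))
    position = batch_idx % bounds[-1]  # where this batch falls inside the current cycle
    # Number of bounds <= position, i.e. the first slot whose bound exceeds position
    return bisect.bisect_right(bounds, position)


if __name__ == "__main__":
    print([active_optimizer_index(i, (5, 1)) for i in range(12)])
```

Inside `training_step(self, batch, batch_idx)`, with the optimizers unpacked from `self.optimizers()` in the order `configure_optimizers` returns them (`optD` first in the example above), you would run only the discriminator block when `active_optimizer_index(batch_idx, (5, 1)) == 0` and only the generator block otherwise, so every discriminator step sees a new batch. `batch_idx` restarts at 0 each epoch, so the cycle restarts too; a counter of your own avoids that if it matters.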

Any advice on which implementation to use in this framework would be greatly appreciated!

You might want to cross-post the question to the Lightning discussion board, as you’ll find the Lightning experts there.