Generator not converging?

Hi guys! I’m using this code to inpaint missing parts of the image:

disc = Discriminator(
    latent_vector_size = LATENT_VECTOR_SIZE,
    features_d = IMAGE_SIZE,
    num_channels = CHANNELS
    )

disc.build()
disc.apply(Discriminator.init_weights)

disc.define_optim(
    learning_rate = LEARNING_RATE,
    beta1 = BETA1
    )

gen = Generator(
    features_g = IMAGE_SIZE, 
    num_channels = CHANNELS
    )

gen.build()
gen.apply(Generator.init_weights)

gen.define_optim(
    learning_rate = LEARNING_RATE,
    beta1 = BETA1
    )

gen.to(device)
disc.to(device)
# MSE against real/fake labels gives a least-squares GAN (LSGAN) loss
adversarial_loss = nn.MSELoss()
for epoch in range(EPOCHS):
    for batch_num, data in enumerate(train_dataloader, 0):   
        real_image = data[0].to(device)
        masked_image = masker(data)
        resized_image = resizer(data)
        
        masked_image = masked_image.to(device)
        resized_image = resized_image.to(device)
        
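        # Size the targets from the actual batch: the last batch may be smaller than BATCH_SIZE
        batch_size = real_image.size(0)
        valid = torch.full((batch_size, 1, 8, 8), real_label, dtype = torch.float, device = device)
        fake = torch.full((batch_size, 1, 8, 8), fake_label, dtype = torch.float, device = device)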
        
        ## TRAIN GENERATOR ##
        gen.zero_grad()
        
        generated_images = gen(masked_image, resized_image)
        generator_loss = adversarial_loss(disc(generated_images), valid)  
        generator_loss.backward()
        
        gen.optimizer.step()
        
        ## TRAIN DISCRIMINATOR ##
        
        disc.zero_grad()
        
        real_loss = adversarial_loss(disc(real_image), valid)
        fake_loss = adversarial_loss(disc(generated_images.detach()), fake)
        
        discriminator_loss = 0.5 * (real_loss + fake_loss)
        discriminator_loss.backward()
        
        disc.optimizer.step()
        
        
        if batch_num % 50 == 0:
            print(
                f'[{epoch + 1}/{EPOCHS}][{batch_num}/{len(train_dataloader)}]  ' 
                f'D_Loss : {round(discriminator_loss.item(),4)}  '    
                f'G_Loss : {round(generator_loss.item(),4)}'  
                )
            
            with torch.no_grad():
                img_grid_fake = torchvision.utils.make_grid(generated_images[:24], normalize = True)
                img_grid_real = torchvision.utils.make_grid(real_image[:24], normalize = True)
                if batch_num == 0 and epoch == 0:
                    img_grid_mask = torchvision.utils.make_grid(masked_image[:24], normalize = True)
                    writer_mask.add_image("Mask", img_grid_mask, global_step = step)
                    img_grid_resi = torchvision.utils.make_grid(resized_image[:24], normalize = True)
                    writer_resi.add_image("Resi", img_grid_resi, global_step = step)

                writer_real.add_image("Real", img_grid_real, global_step = step)
                writer_fake.add_image("Fake", img_grid_fake, global_step = step)
            step = step + 1
        
        if batch_num % 100 == 0:
            G_losses.append(generator_loss.item())
            D_losses.append(discriminator_loss.item())
            
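        batches_done = epoch * len(train_dataloader) + batch_num
        if batches_done % 5000 == 0:
            # eval() followed immediately by train() has no effect on its own; periodic
            # sampling or checkpointing in eval mode would go between these two calls
            gen.eval()
            gen.train()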

(If needed, I’ll post the Generator and Discriminator models as well.)

However, the loss curves look like this:

[Figure: generator and discriminator loss curves]

The discriminator converges relatively fast, but the generator loss is all over the place. The GAN is a CCGAN (context-conditional GAN), based on this article.
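For reading the plot: with `nn.MSELoss` as `adversarial_loss`, the loop optimizes a least-squares GAN (LSGAN) objective. Assuming `real_label = 1` and `fake_label = 0` (their values aren't shown in the post), and writing $\tilde{x} = G(x_m, x_r)$ for the image generated from the masked input $x_m$ and the low-resolution input $x_r$ (notation introduced here for illustration), the two logged losses are:

$$
\mathcal{L}_D = \tfrac{1}{2}\,\mathbb{E}\big[(D(x) - 1)^2\big] + \tfrac{1}{2}\,\mathbb{E}\big[D(\tilde{x})^2\big],
\qquad
\mathcal{L}_G = \mathbb{E}\big[(D(\tilde{x}) - 1)^2\big]
$$

Each expectation is the mean over the batch and over the 8×8 discriminator output. A discriminator that separates real from fake perfectly drives $\mathcal{L}_D \to 0$ and pushes $\mathcal{L}_G$ toward 1, so a D loss near 0 together with a G loss hovering around 1 would indicate that D is winning.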

Any tips? Thanks!

Correct me if I’m wrong: in your code you train the generator and discriminator equally, right? I think the generator is lagging too far behind the discriminator.

In that case, you can try optimizing the generator more often than the discriminator.
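A minimal sketch of that idea, assuming plain `torch.optim.Adam` optimizers in place of the custom `define_optim` helper and tiny `nn.Linear` stand-in models (all hypothetical, since the thread's Generator/Discriminator aren't shown). `G_STEPS_PER_D_STEP` is an illustrative ratio, not a recommended value:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the thread's Generator and Discriminator
gen = nn.Sequential(nn.Linear(16, 16), nn.Tanh())
disc = nn.Sequential(nn.Linear(16, 1))
opt_g = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
adversarial_loss = nn.MSELoss()

G_STEPS_PER_D_STEP = 2  # hypothetical ratio: two generator updates per discriminator update


def train_batch(real, g_input):
    valid = torch.ones(real.size(0), 1)
    fake = torch.zeros(real.size(0), 1)

    # Several generator updates; the discriminator is not updated during these
    for _ in range(G_STEPS_PER_D_STEP):
        opt_g.zero_grad()
        g_loss = adversarial_loss(disc(gen(g_input)), valid)
        g_loss.backward()  # also fills disc grads; cleared by opt_d.zero_grad() below
        opt_g.step()

    # One discriminator update on fresh fakes, detached so G gets no gradient
    opt_d.zero_grad()
    generated = gen(g_input).detach()
    d_loss = 0.5 * (adversarial_loss(disc(real), valid)
                    + adversarial_loss(disc(generated), fake))
    d_loss.backward()
    opt_d.step()
    return g_loss.item(), d_loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    g, d = train_batch(torch.randn(8, 16), torch.randn(8, 16))
    print(f"G_Loss: {g:.4f}  D_Loss: {d:.4f}")
```

Each generator step reuses the same batch here; drawing a new masked/low-res batch for every generator step would be an equally reasonable choice.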

You may also borrow something from this link.

If by “equally” you mean equally often, then yes, that’s the case. Thanks for the suggestion. I’m trying a few runs with the discriminator’s learning rate lowered to see if that helps (would that be equivalent to training the generator more often than the discriminator?). I’ll try your tip as well!
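Since `disc.define_optim(...)` already takes a `learning_rate`, a lower discriminator rate could be passed there. The sketch below shows the same idea with plain `torch.optim.Adam`; the stand-in models, `LEARNING_RATE`, `BETA1`, `D_LR_SCALE`, and the `set_lr` helper are all hypothetical:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the thread's Generator and Discriminator
gen = nn.Sequential(nn.Linear(16, 16), nn.Tanh())
disc = nn.Sequential(nn.Linear(16, 1))

LEARNING_RATE = 2e-4  # assumed value; not shown in the post
BETA1 = 0.5           # assumed value; not shown in the post
D_LR_SCALE = 0.25     # hypothetical: discriminator learning rate is 4x lower

# One optimizer per network, so each can have its own learning rate
opt_g = torch.optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
opt_d = torch.optim.Adam(disc.parameters(), lr=LEARNING_RATE * D_LR_SCALE,
                         betas=(BETA1, 0.999))


def set_lr(optimizer, lr):
    """Change an optimizer's learning rate in place, e.g. between epochs."""
    for group in optimizer.param_groups:
        group["lr"] = lr
```

A lower rate makes each discriminator update smaller, but D is still updated once per batch, so this changes how far D moves per step rather than how often it moves.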

Even if you train D with a smaller LR, the generator’s target still keeps changing. The original GAN paper proposed keeping D near-optimal by training it k times for every generator update.

You might want to give D some time to settle and let G catch up gradually rather than in one sweep.
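Algorithm 1 of the original GAN paper (Goodfellow et al., 2014) updates D for k steps, each on a fresh minibatch, before every G update. A minimal sketch of that schedule with the MSE loss from the post; the stand-in models, random data, and `K = 3` are hypothetical, and reusing the last minibatch's inputs for the generator step is a simplification:

```python
import itertools

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the thread's Generator and Discriminator
gen = nn.Sequential(nn.Linear(16, 16), nn.Tanh())
disc = nn.Sequential(nn.Linear(16, 1))
opt_g = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
adversarial_loss = nn.MSELoss()

K = 3  # hypothetical: discriminator steps per generator step


def d_step(real, g_input):
    opt_d.zero_grad()
    generated = gen(g_input).detach()  # no gradient flows back into G
    d_real = disc(real)
    d_fake = disc(generated)
    loss = 0.5 * (adversarial_loss(d_real, torch.ones_like(d_real))
                  + adversarial_loss(d_fake, torch.zeros_like(d_fake)))
    loss.backward()
    opt_d.step()
    return loss.item()


def g_step(g_input):
    opt_g.zero_grad()
    d_fake = disc(gen(g_input))
    loss = adversarial_loss(d_fake, torch.ones_like(d_fake))
    loss.backward()  # disc grads filled here are cleared by the next d_step
    opt_g.step()
    return loss.item()


def train_epoch(loader):
    batches = iter(loader)
    g_losses, d_losses = [], []
    while True:
        chunk = list(itertools.islice(batches, K))  # up to K fresh minibatches
        if not chunk:
            break
        for real, g_input in chunk:
            d_losses.append(d_step(real, g_input))
        # One generator step, reusing the last minibatch's generator inputs
        g_losses.append(g_step(chunk[-1][1]))
    return g_losses, d_losses
```

Building the targets with `torch.ones_like` / `torch.zeros_like` also keeps them the same shape as the discriminator's output, whatever the batch size or patch size.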

Thanks for the advice! So that means I just need more epochs, and the generator will eventually catch up to the discriminator?