GAN loss is not changing, has gradients but is stuck

Hello,

I am new to GAN models and I have hit a weird problem.

The model does not seem to be updating properly. The discriminator and generator BCE losses both stay at approximately 0.69 the entire time; it's almost as if the two optimizers are not working. I have been banging my head against this problem for three days now, so any help would be much appreciated.

Code:

##### Print #####
model = GAN_Generator(run)
disc = GAN_Discriminator(run)
printNetworkSummary(disc, False)
printNetworkSummary(model, False)
device = "cuda"

data_loader = torch.utils.data.DataLoader(train_set, batch_size=run.batch, shuffle=run.shuf)
val_data_loader = torch.utils.data.DataLoader(val_set, batch_size=run.batch, shuffle=run.shuf)

model.to(device)
disc.to(device)
wandb.watch(model, log='all')

optimizer = getOptimizer(model.parameters(), run)
disc_optimizer = getOptimizer(disc.parameters(), run)
        
# Epoch loop  
for epoch in range(run.E_num):
            model.train()
            disc.train()
            lossy = 0
            # Training Loop ###########################################################################################################
            for image_batch, labels in data_loader:
                disc_optimizer.zero_grad()
                optimizer.zero_grad()
                
                label_real = torch.ones(image_batch.size(0), device=device)
                label_fake = torch.zeros(image_batch.size(0), device=device)

                image_batch = image_batch.to(device)
                labels = labels.to(device).float()
                
                recon_x = model(outputDict, image_batch)

                real_pred = disc(outputDict, image_batch)
                fake_pred = disc(outputDict, recon_x.detach())

                disc_lossy = 0.5 * (F.binary_cross_entropy(real_pred, label_real) + F.binary_cross_entropy(fake_pred, label_fake))
                disc_lossy.backward(retain_graph=True)
                disc_optimizer.step()

                fake_pred = disc(outputDict, recon_x)
                lossy = F.binary_cross_entropy(fake_pred, label_real)

                lossy.backward()              
                optimizer.step()

Models:

The discriminator model is simply a stack of convolutions, ReLUs and BatchNorm layers ending in a linear classifier with a sigmoid activation.

The generator model is actually a convolutional autoencoder which also ends in a sigmoid activation.

(Note: I am using the F.binary_cross_entropy loss, which expects probabilities and so plays nicely with sigmoid outputs.)

Tests:

I have messed around with batch size and step size and have not gotten any improvement. Generally, the loss initially goes from 0.7 to 0.69 and then oscillates between 0.685 and 0.695, with no improvement in the generated image.
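For reference, 0.69 is not an arbitrary number: it is approximately ln 2, which is what binary cross-entropy gives when the predicted probability is 0.5, whatever the label. So a loss pinned near 0.69 suggests the discriminator is outputting about 0.5 for everything. A minimal sketch in plain Python (the helper `bce` is just a name made up for illustration, not part of my training code):

```python
import math

def bce(p, y):
    """Binary cross-entropy for one predicted probability p and target y (0 or 1)."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# A discriminator that outputs 0.5 for everything, i.e. one that cannot
# tell real from fake, gets the same loss on both labels.
loss_real = bce(0.5, 1)
loss_fake = bce(0.5, 0)
disc_loss = 0.5 * (loss_real + loss_fake)

print(disc_loss, math.log(2))  # both are ln 2, roughly 0.6931
```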

To make sure the parameters were updating at all I used:

a = list(disc.parameters())[0].clone()
disc_lossy.backward(retain_graph=True)
disc_optimizer.step()
b = list(disc.parameters())[0].clone()
print(torch.equal(a, b))

This printed False, so the parameters are updating.

I then used this function to check the gradients.

def checkGrad(model, loss):
    print("loss grad : ", loss.grad)
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(name, param.grad.sum())
        else:
            print(name, param.grad)

This printed:
loss grad : 1
conv1.conv.weight tensor(0.3062, device='cuda:0')
conv1.conv.bias tensor(0.1214, device='cuda:0')
conv1.norm.weight tensor(-0.0040, device='cuda:0')
conv1.norm.bias tensor(-0.0181, device='cuda:0')
…etc.

That all looks fine to me, so I am confused by what is going wrong.
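One caveat with summing gradients as in the check above: positive and negative entries can cancel, so a sum near zero does not necessarily mean a small gradient. A possible variant reports the L2 norm per parameter instead. This is only a sketch; `grad_norms` and the toy network `toy` are hypothetical stand-ins, not my real discriminator.

```python
import torch
import torch.nn as nn

def grad_norms(model):
    """Return {parameter name: L2 norm of its gradient, or None if it has no gradient}."""
    norms = {}
    for name, param in model.named_parameters():
        norms[name] = None if param.grad is None else param.grad.norm().item()
    return norms

# Toy stand-in for the discriminator (hypothetical, not my real model).
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid())
pred = toy(torch.randn(16, 4)).squeeze(1)
loss = nn.functional.binary_cross_entropy(pred, torch.ones(16))
loss.backward()

for name, norm in grad_norms(toy).items():
    print(name, norm)
```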

Once again, any help would be greatly appreciated; I have found nothing online that matches my problem. Thanks!

Ahhhhh, finally figured it out. For anyone else with a similar problem: it was because I had retain_graph=True on the first backward() call (disc_lossy.backward). That flag is useful in other contexts, where multiple losses backpropagate through the same graph, but not in this instance.
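For concreteness, here is roughly what one training step looks like without retain_graph. It is a self-contained sketch, not my actual code: the tiny `gen` and `disc` networks, the Adam optimizers and the learning rate are assumptions made up so the example runs on its own, and it runs on CPU for simplicity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy stand-ins for my generator (an autoencoder) and discriminator.
gen = nn.Sequential(nn.Linear(8, 8), nn.Sigmoid())
disc = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())
gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-3)
disc_opt = torch.optim.Adam(disc.parameters(), lr=1e-3)

def train_step(image_batch):
    n = image_batch.size(0)
    label_real = torch.ones(n)
    label_fake = torch.zeros(n)

    recon_x = gen(image_batch)

    # Discriminator update: detach() keeps this backward pass out of the generator.
    disc_opt.zero_grad()
    real_pred = disc(image_batch).squeeze(1)
    fake_pred = disc(recon_x.detach()).squeeze(1)
    disc_loss = 0.5 * (F.binary_cross_entropy(real_pred, label_real)
                       + F.binary_cross_entropy(fake_pred, label_fake))
    disc_loss.backward()  # no retain_graph: this graph is not reused
    disc_opt.step()

    # Generator update: a fresh forward pass through the updated discriminator.
    gen_opt.zero_grad()
    fake_pred = disc(recon_x).squeeze(1)
    gen_loss = F.binary_cross_entropy(fake_pred, label_real)
    gen_loss.backward()
    gen_opt.step()

    return disc_loss.item(), gen_loss.item()

d_loss, g_loss = train_step(torch.rand(16, 8))
print(d_loss, g_loss)
```

The generator's backward pass also leaves gradients on the discriminator's parameters, but those are cleared by disc_opt.zero_grad() at the start of the next step.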

Actually, sneaky, sneaky… there was another problem, and this one I do not understand. With BatchNorm on each layer, the discriminator loss gets stuck at ~0.69. I do not know why, but when I replace the BatchNorm layers with dropout everything starts working.

This person has had a similar experience: