Issue with WGAN with gradient penalty - negative losses for the discriminator and generator

I’m new to GANs. I have been trying to train a WGAN with gradient penalty (WGAN-GP) on single-channel 3D micro-CT images of shape (H, W, D). However, both the discriminator and generator losses come out negative. Can you please point out the reason for that? I have included the discriminator and generator architectures and the training loop below.

Discriminator(
  (main): Sequential(
    (0): Conv3d(1, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (3): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv3d(128, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (6): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv3d(256, 512, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (9): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv3d(512, 1, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)
  )
)
Generator(
  (main): Sequential(
    (0): ConvTranspose3d(100, 512, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)
    (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose3d(512, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (4): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose3d(256, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (7): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose3d(128, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (10): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose3d(64, 1, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (13): Tanh()
  )
)
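
For reference, these layers take a latent vector of shape (nz=100, 1, 1, 1) up to a single-channel volume and back down to one critic score per sample, which implies the training volumes are 64×64×64. A quick shape check, rebuilding the printed stacks as plain nn.Sequential modules (a sketch; the actual Generator/Discriminator classes are assumed to match the printouts above):

import torch
import torch.nn as nn

netG = nn.Sequential(  # mirrors the printed Generator
    nn.ConvTranspose3d(100, 512, 4, 1, 0, bias=False), nn.BatchNorm3d(512), nn.ReLU(True),
    nn.ConvTranspose3d(512, 256, 4, 2, 1, bias=False), nn.BatchNorm3d(256), nn.ReLU(True),
    nn.ConvTranspose3d(256, 128, 4, 2, 1, bias=False), nn.BatchNorm3d(128), nn.ReLU(True),
    nn.ConvTranspose3d(128, 64, 4, 2, 1, bias=False), nn.BatchNorm3d(64), nn.ReLU(True),
    nn.ConvTranspose3d(64, 1, 4, 2, 1, bias=False), nn.Tanh(),
)
netD = nn.Sequential(  # mirrors the printed Discriminator
    nn.Conv3d(1, 64, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, True),
    nn.Conv3d(64, 128, 4, 2, 1, bias=False), nn.InstanceNorm3d(128, affine=True), nn.LeakyReLU(0.2, True),
    nn.Conv3d(128, 256, 4, 2, 1, bias=False), nn.InstanceNorm3d(256, affine=True), nn.LeakyReLU(0.2, True),
    nn.Conv3d(256, 512, 4, 2, 1, bias=False), nn.InstanceNorm3d(512, affine=True), nn.LeakyReLU(0.2, True),
    nn.Conv3d(512, 1, 4, 1, 0, bias=False),
)

z = torch.randn(2, 100, 1, 1, 1)
fake = netG(z)
print(fake.shape)        # torch.Size([2, 1, 64, 64, 64])
print(netD(fake).shape)  # torch.Size([2, 1, 1, 1, 1]) -> one critic score per sample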
import torchvision.utils as vutils

CRITIC_ITERATIONS = 5  # critic (discriminator) updates per generator update
LAMBDA_GP = 10         # gradient penalty weight

# Number of training epochs
num_epochs = 5
dataloader = train_loader

# Convention for real and fake labels (note: never used by the WGAN-GP losses below)
real_label = 1.0
fake_label = 0.0



# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        
        ############################
        # (1) Update D (critic): minimize -(mean D(real) - mean D(fake)) + LAMBDA_GP * GP
        ###########################
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)  # batch size (number of images in this batch)
        for _ in range(CRITIC_ITERATIONS):
            netD.zero_grad()
            noise = torch.randn(b_size, nz, 1, 1, 1, device=device)
            netD_real_output = netD(real_cpu).view(-1)
            fake = netG(noise)
            # pass the fakes through the critic; detach so these critic
            # updates do not backpropagate into the generator
            netD_fake_output = netD(fake.detach()).view(-1)
            gp = gradientpenalty(netD, real_cpu, fake, device=device)
            errD = (
                -(torch.mean(netD_real_output) - torch.mean(netD_fake_output))
                + LAMBDA_GP * gp
            )
            errD.backward()
            optimizerD.step()
        
        ############################
        # (2) Update G network: minimize -mean D(G(z)), i.e. maximize the critic's score on fakes
        ###########################
        netG.zero_grad()
        noise = torch.randn(b_size, nz, 1, 1, 1, device=device)
        fake = netG(noise)
        errG_fake_output = netD(fake).view(-1)  # fake is not detached here: we want gradients to flow back through the critic into the generator
        # Calculate G's loss based on this output
        errG = - torch.mean(errG_fake_output)
        # Calculate gradients for G
        errG.backward()

        # Update G
        optimizerG.step()
        
        # Output training stats
        if i % 50 == 0:
            print(f"Epoch [ {epoch}/{num_epochs}] Batch {i}/{len(dataloader)} \
                    loss D: {errD:.4f}, loss G: {errG:.4f}")

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(fake)

        iters += 1
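
The gradientpenalty function called in the critic loop isn't shown. Assuming it is the standard WGAN-GP penalty of Gulrajani et al. (2017), a minimal sketch matching the call signature gradientpenalty(netD, real, fake, device=device) would be:

import torch

def gradientpenalty(critic, real, fake, device="cpu"):
    # Interpolate between real and fake volumes with one random weight per
    # sample and penalize deviations of the critic's gradient norm from 1.
    b_size = real.size(0)
    eps = torch.rand(b_size, 1, 1, 1, 1, device=device)  # broadcasts over C, H, W, D
    # detach the inputs so the penalty never backpropagates into the generator
    interp = eps * real.detach() + (1.0 - eps) * fake.detach()
    interp.requires_grad_(True)
    scores = critic(interp).view(b_size, -1)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty itself is differentiable
        retain_graph=True,
    )[0]
    grad_norm = grads.view(b_size, -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()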

Starting Training Loop...
Epoch [0/5] Batch 0/11 loss D: 5045.0757, loss G: 1.1221
Epoch [1/5] Batch 0/11 loss D: 579.6640, loss G: -9.8872
Epoch [2/5] Batch 0/11 loss D: 61.9519, loss G: -19.0555
Epoch [3/5] Batch 0/11 loss D: -14.0542, loss G: -35.1635

Oh WGANs, reminding me of the good old times of early PyTorch!

I think negative losses may legitimately happen here, but the reason is a bit elaborate:

The WGAN approximates the 1-Wasserstein distance in its Kantorovich-Rubinstein dual formulation, where you take a supremum over all admissible (1-Lipschitz) test functions.
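
Concretely, the dual form reads

$$ W_1(P_r, P_g) \;=\; \sup_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)], $$

and errD in the loop above is a sample estimate of the negative of the right-hand side, plus the penalty term.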

The discriminator acts as the test function, and WGAN training takes gradient ascent steps on sample-based approximations of these expectations in order to approach the supremum (with the gradient penalty trying to enforce the admissibility condition on the test function). While the true supremum is positive for non-identical distributions (or it would not be a distance we’re talking about), both the sampling and the non-optimality of the current test function (the discriminator) may lead to negative values.
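
As a toy illustration with made-up numbers: whenever the current critic happens to score a batch of fakes higher than the reals, the sample estimate of the distance comes out negative even though the true distance cannot be.

import torch

f_real = torch.tensor([0.1, -0.3, 0.2])  # hypothetical critic scores on real samples
f_fake = torch.tensor([0.5, 0.4, 0.0])   # hypothetical critic scores on fake samples
print(f_real.mean() - f_fake.mean())     # tensor(-0.3000): negative estimate, while W1 >= 0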

Best regards

Thomas