I’m new to GANs. I have been trying to train a WGAN on single-channel 3D micro-CT images of shape (H, W, D). However, both the discriminator (critic) and generator losses become negative. Can you please point out the reason for that? I have provided the discriminator and generator architectures and the training loop below.
Discriminator(
(main): Sequential(
(0): Conv3d(1, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
(3): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv3d(128, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
(6): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv3d(256, 512, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
(9): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv3d(512, 1, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)
)
)
Generator(
(main): Sequential(
(0): ConvTranspose3d(100, 512, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)
(1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose3d(512, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
(4): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose3d(256, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
(7): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose3d(128, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
(10): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose3d(64, 1, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
(13): Tanh()
)
)
import torchvision.utils as vutils
# Number of critic (discriminator) updates per generator update.
# NOTE: the misspelled name "CRTIC" is kept so other cells that reference it
# keep working.
CRTIC_ITERATIONS = 5
# Weight of the gradient-penalty term in the critic loss.
LAMBDA_GP = 10
# Number of training epochs
num_epochs = 5
dataloader = train_loader
# Leftover label constants from the DCGAN template. The WGAN losses below use
# raw critic scores rather than 0/1 targets, so this loop does not read them;
# they are kept only in case other cells reference them.
real_label = 1.0
fake_label = 0.0
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update critic D:
        #     minimize -(mean D(real) - mean D(G(z))) + LAMBDA_GP * GP
        ############################
        # Despite the name, this tensor lives on `device`.
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)  # number of volumes in this batch
        for _ in range(CRTIC_ITERATIONS):
            netD.zero_grad()
            noise = torch.randn(b_size, nz, 1, 1, 1, device=device)
            netD_real_output = netD(real_cpu).view(-1)
            fake = netG(noise)
            # The fake batch must be scored by the critic. The previous code
            # used the raw generated voxels (fake.detach().view(-1)), so the
            # "fake" term was the mean voxel intensity rather than D(G(z)).
            # detach() keeps the critic update from backpropagating into G.
            netD_fake_output = netD(fake.detach()).view(-1)
            # Use the on-device real batch (previously the CPU tensor data[0]).
            # NOTE(review): `fake` is passed without .detach() exactly as
            # before; whether `gradientpenalty` relies on it requiring grad
            # cannot be seen from here -- confirm before detaching it.
            gp = gradientpenalty(netD, real_cpu, fake, device=device)
            errD = (
                -(torch.mean(netD_real_output) - torch.mean(netD_fake_output))
                + gp * LAMBDA_GP
            )
            errD.backward()
            optimizerD.step()

        ############################
        # (2) Update G network: minimize -mean D(G(z))
        #     (WGAN generator loss; no log/sigmoid as in the DCGAN objective)
        ############################
        netG.zero_grad()
        noise = torch.randn(b_size, nz, 1, 1, 1, device=device)
        fake = netG(noise)
        # Not detached: G's weights must receive gradients through the critic.
        # This backward also accumulates grads in netD, which the
        # netD.zero_grad() at the start of the next critic step clears.
        errG_fake_output = netD(fake).view(-1)
        errG = -torch.mean(errG_fake_output)
        errG.backward()
        optimizerG.step()

        # Output training stats. In a WGAN both values may legitimately be
        # negative: errD is the negated critic gap plus the penalty, and errG
        # is the negated mean critic score of fakes. The sign alone does not
        # indicate a bug.
        if i % 50 == 0:
            print(
                f"Epoch [ {epoch}/{num_epochs}] Batch {i}/{len(dataloader)} "
                f"loss D: {errD.item():.4f}, loss G: {errG.item():.4f}"
            )
        # Save losses for plotting later (errD is from the last critic step).
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or (
            (epoch == num_epochs - 1) and (i == len(dataloader) - 1)
        ):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(fake)
        iters += 1
Starting Training Loop...
Epoch [ 0/5] Batch 0/11 loss D: 5045.0757, loss G: 1.1221
Epoch [ 1/5] Batch 0/11 loss D: 579.6640, loss G: -9.8872
Epoch [ 2/5] Batch 0/11 loss D: 61.9519, loss G: -19.0555
Epoch [ 3/5] Batch 0/11 loss D: -14.0542, loss G: -35.1635