GAN: Calculate the loss functions


I’m new to PyTorch (and also to GANs), and I need to compute the loss functions for both the discriminator and the generator.

The code is standard:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Choose a value for the prior dimension
PRIOR_N = 25


# Define the generator
class Generator(nn.Module):
    """Two-layer network mapping a PRIOR_N-dimensional latent vector to 2 outputs."""

    def __init__(self):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise attribute assignment of nn.Linear raises AttributeError.
        super().__init__()
        self.fc1 = nn.Linear(PRIOR_N, 2)
        self.fc2 = nn.Linear(2, 2)

    def forward(self, z):
        """Map latent samples ``z`` (last dimension PRIOR_N) to 2 output features.

        Defined as ``forward`` rather than ``__call__`` so that calling the
        instance (``generator(z)``) goes through ``nn.Module.__call__``.
        """
        h = F.relu(self.fc1(z))
        return self.fc2(h)

    def generate(self, batchlen):
        """Draw ``batchlen`` standard-normal latent vectors and return the generated points."""
        z = torch.randn(batchlen, PRIOR_N)
        return self(z)

# Define the discriminator
class Discriminator(nn.Module):
    """Two-layer network mapping a 2-feature input to a single unnormalised score."""

    def __init__(self):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise attribute assignment of nn.Linear raises AttributeError.
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        """Return one score per input row; no sigmoid is applied to the output.

        Defined as ``forward`` rather than ``__call__`` so that calling the
        instance (``discriminator(x)``) goes through ``nn.Module.__call__``.
        """
        h = F.relu(self.fc1(x))
        return self.fc2(h)
# Number of times to train the discriminator between two generator steps
# NOTE(review): the assignment was missing from the post although the training
# loop reads TRAIN_RATIO; 1 is a placeholder value — confirm.
TRAIN_RATIO = 1
# Total number of training iterations for the generator
N_ITER = 20001
# Batch size to use
# NOTE(review): the assignment was missing from the post although the training
# loop reads BATCHLEN; 128 is a placeholder value — confirm.
BATCHLEN = 128

# One Adam optimizer per network so each can be stepped independently.
generator = Generator()
optim_gen = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.9))

discriminator = Discriminator()
optim_disc = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.9))

# Training-loop skeleton: each outer iteration runs TRAIN_RATIO discriminator
# passes followed by one generator pass. Both losses are placeholders (0) and
# no zero_grad()/backward()/step() call appears in this loop, so as written no
# parameter is updated.
# NOTE(review): generate_batch and plt are not defined in the visible code;
# presumably a real-data sampler and matplotlib.pyplot — confirm.
for i in range(N_ITER):
    # train the discriminator
    for _ in range(TRAIN_RATIO):
        # NOTE(review): the original comment here described zero_grad(), but no
        # such call is present — presumably optim_disc.zero_grad() belongs here.
        # Create the real points
        real_batch = generate_batch(BATCHLEN)
        # Create the fake points
        fake_batch = generator.generate(BATCHLEN)
        # Compute here the discriminator loss, using functions like torch.sum, torch.exp,
        # torch.log, torch.softplus, using real_batch and fake_batch
        disc_loss = 0 # FILL HERE
    # train the generator
    fake_batch = generator.generate(BATCHLEN)
    # Compute here the generator loss, using fake_batch
    gen_loss = 0 # FILL HERE
    if i%100 == 0:
        # Every 100 iterations: print both losses in scientific notation.
        print('step {}: discriminator: {:.3e}, generator: {:.3e}'.format(i, float(disc_loss), float(gen_loss)))
        # plot the result
        # Fresh batches of 1024 points; the generated batch is detached before plotting.
        real_batch = generate_batch(1024)
        fake_batch = generator.generate(1024).detach()
        # Scatter the first two columns of each batch against each other.
        plt.scatter(real_batch[:,0], real_batch[:,1], s=2.0, label='real data')
        plt.scatter(fake_batch[:,0], fake_batch[:,1], s=2.0, label='fake data')

So, I can’t use pre-built loss functions; instead, I have to use:
For the discriminator loss: functions like torch.sum, torch.exp, torch.log, torch.softplus, using real_batch and fake_batch
For the generator loss: fake_batch

The one thing I’m pretty sure of is that I first have to create two new variables:
disc_real = discriminator(real_batch)
disc_fake = discriminator(fake_batch)

but then, I don’t really know what to do. I thought about beginning with something like:
disc_loss = torch.log(disc_real) + torch.log(1. - disc_fake)

Thanks for your help

This should answer your question:

class WGANLoss(torch.nn.Module):
    """Signed mean of the scores, used as an adversarial criterion.

    The mean of ``predicted`` is multiplied by -1 when ``true`` is 1/True and
    by +1 when ``true`` is 0/False.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predicted, true):
        # 1 - 2*true maps True -> -1 and False -> +1.
        sign = 1 - 2 * true
        return sign * predicted.mean()
def forward_D(latent_code,Img_real):
    """Score a real batch and a generated batch with D for the discriminator update.

    The generated batch is detached before scoring, so gradients of the
    returned fake score do not reach G. Returns ``(d_real, d_fake)``.
    """
    generated = G(latent_code).detach()
    real_score = D(Img_real)
    fake_score = D(generated)
    return real_score, fake_score
def forward_G(latent_code):
    """Generate a batch from ``latent_code`` and return D's score for it.

    The generated batch is not detached, so the score stays connected to G.
    """
    return D(G(latent_code))

def backward_G(d_fake,optim_G):
    """Run one generator update from the discriminator's score of generated data.

    Args:
        d_fake: score tensor for generated samples, still attached to the
            generator's graph.
        optim_G: optimizer holding the generator's parameters.

    Returns:
        The generator loss tensor (previously nothing was returned).
    """
    g_loss = compute_G_loss(d_fake)
    # Clear stale gradients, backpropagate, then apply the parameter update;
    # without these calls the loss was computed and discarded.
    optim_G.zero_grad()
    g_loss.backward()
    optim_G.step()
    return g_loss

def backward_D(d_real,d_fake,optim_D):
    """Run one discriminator update from its scores on real and generated data.

    Args:
        d_real: score tensor for real samples.
        d_fake: score tensor for generated samples.
        optim_D: optimizer holding the discriminator's parameters.

    Returns:
        The discriminator loss tensor (previously nothing was returned).
    """
    d_loss = compute_D_loss(d_real,d_fake)
    # Clear stale gradients, backpropagate, then apply the parameter update;
    # without these calls the loss was computed and discarded.
    optim_D.zero_grad()
    d_loss.backward()
    optim_D.step()
    return d_loss

def compute_G_loss(d_fake):
    """Return the adversarial loss of ``d_fake`` evaluated with the target flag True."""
    return compute_adv_loss(d_fake, True)

def compute_D_loss(d_real,d_fake):
    """Return the sum of the adversarial losses for real (flag True) and fake (flag False) scores."""
    real_term = compute_adv_loss(d_real, True)
    fake_term = compute_adv_loss(d_fake, False)
    return real_term + fake_term

# Define here latent_codes, Imgs, G (Generator) and D (Discriminator)
# Adversarial criterion used by compute_G_loss / compute_D_loss.
# The class defined above is WGANLoss; the name WGAN is not defined anywhere.
compute_adv_loss = WGANLoss()

# Adam with default hyper-parameters, one optimizer per network.
optim_G = torch.optim.Adam(G.parameters())
optim_D = torch.optim.Adam(D.parameters())

for latent_code, Img_real in zip(latent_codes, Imgs):

    # ===update D===
    d_real, d_fake = forward_D(latent_code, Img_real)
    backward_D(d_real, d_fake, optim_D)

    # ===update G===
    # Re-run the forward pass so the fake score stays attached to G's graph.
    d_fake = forward_G(latent_code)
    backward_G(d_fake, optim_G)

1 Like

Thanks for your answer.
I would prefer to avoid using pre-built functions like WGAN() (also, I’m not sure how this one works).

Is it possible to simply use the functions my professor asks me to use in the code?

Thanks !!