GAN: Calculate the loss functions

Hello,

I’m new to PyTorch (and also to GANs), and I need to compute the loss functions for both the discriminator and the generator.

The code is standard:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Dimension of the latent (prior) noise vector fed to the generator
PRIOR_N = 25


# Define the generator
class Generator(nn.Module):
    """Two-layer MLP mapping a PRIOR_N-dimensional noise vector to a 2-D point.

    The computation lives in ``forward`` (not ``__call__``) so that
    ``nn.Module.__call__`` keeps running its hook machinery; calling
    ``generator(z)`` works exactly as before.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(PRIOR_N, 2)
        self.fc2 = nn.Linear(2, 2)

    def forward(self, z):
        """Map latent codes ``z`` (last dimension PRIOR_N) to 2-D outputs."""
        h = F.relu(self.fc1(z))
        return self.fc2(h)

    def generate(self, batchlen):
        """Draw ``batchlen`` standard-normal latent codes and return the generated points."""
        z = torch.randn(batchlen, PRIOR_N)
        # Go through nn.Module.__call__ rather than calling forward directly.
        return self(z)


# Define the discriminator
class Discriminator(nn.Module):
    """Two-layer MLP that maps a 2-D point to a single unnormalised score (logit).

    The computation lives in ``forward`` (not ``__call__``) so that
    ``nn.Module.__call__`` keeps running its hook machinery; calling
    ``discriminator(x)`` works exactly as before.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        """Return one raw score per input row; no sigmoid is applied here."""
        h = F.relu(self.fc1(x))
        return self.fc2(h)
# How many discriminator updates are performed per generator update
TRAIN_RATIO = 1
# Number of generator updates over the whole training run
N_ITER = 20001
# Number of samples per batch
BATCHLEN = 128

# Generator and the Adam optimizer that updates its parameters
generator = Generator()
optim_gen = torch.optim.Adam(
    generator.parameters(),
    lr=1e-4,
    betas=(0.5, 0.9),
)

# Discriminator and its Adam optimizer (same hyper-parameters as above)
discriminator = Discriminator()
optim_disc = torch.optim.Adam(
    discriminator.parameters(),
    lr=1e-4,
    betas=(0.5, 0.9),
)

for i in range(N_ITER):
    # --- train the discriminator ---
    for _ in range(TRAIN_RATIO):
        # Clear the gradients accumulated on the discriminator's parameters
        # (the generator step below also back-propagates through it).
        discriminator.zero_grad()
        # Real points; assumed to be a float tensor of shape (BATCHLEN, 2) -- TODO confirm
        real_batch = generate_batch(BATCHLEN)
        # Fake points; detached so this step does not back-propagate into the generator.
        fake_batch = generator.generate(BATCHLEN).detach()

        disc_real = discriminator(real_batch)
        disc_fake = discriminator(fake_batch)
        # The discriminator returns raw scores s (logits), with p = sigmoid(s).
        # Standard GAN loss  -log(p_real) - log(1 - p_fake), written stably:
        #   -log(sigmoid(s))     = softplus(-s)
        #   -log(1 - sigmoid(s)) = softplus(s)
        disc_loss = F.softplus(-disc_real).mean() + F.softplus(disc_fake).mean()
        disc_loss.backward()
        optim_disc.step()

    # --- train the generator ---
    generator.zero_grad()
    fake_batch = generator.generate(BATCHLEN)
    # Non-saturating generator loss: -log(sigmoid(D(G(z)))) = softplus(-D(G(z)))
    gen_loss = F.softplus(-discriminator(fake_batch)).mean()
    gen_loss.backward()
    optim_gen.step()

    if i % 100 == 0:
        print('step {}: discriminator: {:.3e}, generator: {:.3e}'.format(i, float(disc_loss), float(gen_loss)))
        # plot the result (plt is presumably matplotlib.pyplot, imported elsewhere)
        real_batch = generate_batch(1024)
        fake_batch = generator.generate(1024).detach()
        plt.scatter(real_batch[:,0], real_batch[:,1], s=2.0, label='real data')
        plt.scatter(fake_batch[:,0], fake_batch[:,1], s=2.0, label='fake data')
        plt.show()

So I can’t use pre-built loss functions; instead I have to use:
For the discriminator loss: functions like torch.sum, torch.exp, torch.log, torch.softplus, applied to real_batch and fake_batch
For the generator loss: fake_batch only

The one thing I’m fairly sure of is that I first have to create two new variables:
disc_real = discriminator(real_batch)
disc_fake = discriminator(fake_batch)

but then, I don’t really know what to do. I thought about beginning with something like:
disc_loss = torch.log(disc_real) + torch.log(1. - disc_fake)

Thanks for your help

This should answer your question:


class WGANLoss(torch.nn.Module):
    """Wasserstein-style adversarial loss: the mean score with a label-dependent sign."""

    def forward(self, predicted, true):
        """Return ``-mean(predicted)`` when ``true`` is True and ``+mean(predicted)`` when False."""
        # 1 - 2*true maps True -> -1 and False -> +1.
        sign = 1 - 2 * true
        return sign * torch.mean(predicted)
       
def forward_D(latent_code, Img_real):
    """Score a real batch and a generated batch with D for a discriminator update.

    Relies on the module-level ``G`` and ``D``. Returns ``(d_real, d_fake)``.
    """
    # Detach the generated images so the resulting score carries no graph back into G.
    generated = G(latent_code).detach()
    real_score = D(Img_real)
    fake_score = D(generated)
    return real_score, fake_score
    
def forward_G(latent_code):
    """Generate images from ``latent_code`` with G and return D's score for them.

    Nothing is detached here, so the returned score stays connected to G.
    """
    return D(G(latent_code))


def backward_G(d_fake, optim_G):
    """Back-propagate the generator loss computed from ``d_fake`` and step ``optim_G``."""
    compute_G_loss(d_fake).backward()
    optim_G.step()


def backward_D(d_real, d_fake, optim_D):
    """Back-propagate the discriminator loss for the two scores and step ``optim_D``."""
    # retain_graph=False: the graph is not kept after this backward pass.
    compute_D_loss(d_real, d_fake).backward(retain_graph=False)
    optim_D.step()


def compute_G_loss(d_fake):
    """Return the generator's adversarial loss: fake scores evaluated against the label True."""
    return compute_adv_loss(d_fake, True)

def compute_D_loss(d_real, d_fake):
    """Return the discriminator's adversarial loss.

    Sum of the loss on real scores (label True) and on fake scores (label False).
    """
    return compute_adv_loss(d_real, True) + compute_adv_loss(d_fake, False)
    

# Define here latent_codes, Imgs, G (Generator) and D (Discriminator)

# The adversarial criterion is the WGANLoss module defined in this file.
compute_adv_loss = WGANLoss()

optim_G = torch.optim.Adam(G.parameters())
optim_D = torch.optim.Adam(D.parameters())

for latent_code, Img_real in zip(latent_codes, Imgs):

    # === update D ===
    optim_D.zero_grad()
    d_real, d_fake = forward_D(latent_code, Img_real)
    backward_D(d_real, d_fake, optim_D)

    # === update G ===
    optim_G.zero_grad()
    d_fake = forward_G(latent_code)
    backward_G(d_fake, optim_G)


1 Like

Thanks for your answer.
I would prefer to avoid using pre-built functions like WGAN() (also, I’m not sure how this one works).

Is it possible to simply use the functions my professor asks me to use in the code?

Thanks !!