# GAN: Calculate the loss functions

Hello,

I’m new to PyTorch (and to GANs), and I need to compute the loss functions for both the discriminator and the generator.

Here is the code, which is a fairly standard setup:

``````python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Choose a value for the prior dimension
PRIOR_N = 25


# Define the generator
class Generator(nn.Module):
    """Map PRIOR_N-dimensional latent noise to 2-D points.

    Architecture: Linear(PRIOR_N, 2) -> ReLU -> Linear(2, 2).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(PRIOR_N, 2)
        self.fc2 = nn.Linear(2, 2)

    def forward(self, z):
        """Return generated points for latent codes ``z``.

        Args:
            z: tensor whose last dimension is PRIOR_N, e.g. (batch, PRIOR_N).

        Returns:
            Tensor with the same leading dimensions and a last dimension of 2.
        """
        # Define ``forward`` rather than overriding ``__call__``:
        # nn.Module.__call__ runs registered hooks and then dispatches here.
        h = F.relu(self.fc1(z))
        return self.fc2(h)

    def generate(self, batchlen):
        """Sample ``batchlen`` latent codes from N(0, I) and return ``self(z)``."""
        z = torch.randn(batchlen, PRIOR_N)
        return self(z)

# Define the discriminator
class Discriminator(nn.Module):
    """Score 2-D points with a single real value.

    Architecture: Linear(2, 2) -> ReLU -> Linear(2, 1). No sigmoid is
    applied, so the output is an unbounded logit; losses should use
    softplus / log-sigmoid rather than ``log(D(x))``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        """Return one logit per input point (last dimension 1)."""
        # Define ``forward`` rather than overriding ``__call__`` so that
        # nn.Module.__call__ (hooks, etc.) still wraps this computation.
        h = F.relu(self.fc1(x))
        return self.fc2(h)
``````
``````python
# Number of times to train the discriminator between two generator steps
TRAIN_RATIO = 1
# Total number of training iterations for the generator
N_ITER = 20001
# Batch size to use
BATCHLEN = 128

generator = Generator()

discriminator = Discriminator()

# optim_gen / optim_disc are used below but were not created in this
# snippet; one Adam optimizer per network (learning rate is a tunable choice).
optim_gen = torch.optim.Adam(generator.parameters(), lr=1e-3)
optim_disc = torch.optim.Adam(discriminator.parameters(), lr=1e-3)

# The discriminator returns raw logits s = D(x) (its last layer is a plain
# nn.Linear, no sigmoid), so P(x is real) = sigmoid(s). With the identities
#     -log(sigmoid(s))     == softplus(-s)
#     -log(1 - sigmoid(s)) == softplus(s)
# the standard GAN losses can be written without ever taking log(0).
# torch.mean is used instead of torch.sum so the loss scale does not depend
# on BATCHLEN.
for i in range(N_ITER):
    # train the discriminator
    for _ in range(TRAIN_RATIO):
        # Create the real points
        real_batch = generate_batch(BATCHLEN)
        # Create the fake points; detach so this step does not backprop
        # into the generator.
        fake_batch = generator.generate(BATCHLEN).detach()

        disc_real = discriminator(real_batch)
        disc_fake = discriminator(fake_batch)
        # Maximize log D(real) + log(1 - D(fake)), i.e. minimize its negation.
        disc_loss = (torch.mean(F.softplus(-disc_real))
                     + torch.mean(F.softplus(disc_fake)))

        # Clear gradients from the previous step (including those the
        # generator step wrote into the discriminator's parameters).
        optim_disc.zero_grad()
        disc_loss.backward()
        optim_disc.step()

    # train the generator
    fake_batch = generator.generate(BATCHLEN)
    # Non-saturating generator loss: maximize log D(fake), i.e. minimize
    # -log(sigmoid(D(fake))) = softplus(-D(fake)).
    gen_loss = torch.mean(F.softplus(-discriminator(fake_batch)))

    optim_gen.zero_grad()
    gen_loss.backward()
    optim_gen.step()

    if i % 100 == 0:
        print('step {}: discriminator: {:.3e}, generator: {:.3e}'.format(i, float(disc_loss), float(gen_loss)))

# plot the result
# NOTE(review): generate_batch and plt (matplotlib.pyplot) are not defined in
# this snippet; they are assumed to be provided elsewhere in the exercise.
real_batch = generate_batch(1024)
fake_batch = generator.generate(1024).detach()
plt.scatter(real_batch[:, 0], real_batch[:, 1], s=2.0, label='real data')
plt.scatter(fake_batch[:, 0], fake_batch[:, 1], s=2.0, label='fake data')
plt.show()
``````

So I can’t use prebuilt loss functions; instead I have to use:
- for the discriminator loss: functions like torch.sum, torch.exp, torch.log, torch.softplus, applied to real_batch and fake_batch;
- for the generator loss: only fake_batch.

What I’m fairly sure of is that I first have to create two new variables:

```python
disc_real = discriminator(real_batch)
disc_fake = discriminator(fake_batch)
```

But after that, I don’t really know what to do. I thought about starting with something like:

```python
disc_loss = torch.log(disc_real) + torch.log(1. - disc_fake)
```

``````
class WGANLoss(torch.nn.Module):
    """Signed mean of critic outputs, usable as a Wasserstein GAN loss.

    ``true`` picks the sign: ``true == 1`` yields ``-mean(predicted)`` and
    ``true == 0`` yields ``+mean(predicted)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predicted, true):
        # 1 - 2*true maps label 1 -> -1 and label 0 -> +1.
        sign = 1 - 2 * true
        avg = torch.mean(predicted)
        return sign * avg

def forward_D(latent_code, Img_real):
    """Run D on a real batch and on a freshly generated fake batch.

    Returns:
        (d_real, d_fake): D's output on ``Img_real`` and on ``G(latent_code)``.
    """
    # Detach the fakes so the discriminator update cannot push gradients into G.
    fake_imgs = G(latent_code).detach()
    return D(Img_real), D(fake_imgs)

def forward_G(latent_code):
    """Return D's output on ``G(latent_code)``; the graph through G is kept."""
    return D(G(latent_code))

def backward_G(d_fake, optim_G):
    """Compute the generator loss from ``d_fake`` and take one optimizer step.

    Args:
        d_fake: discriminator output on generated samples (graph through G kept).
        optim_G: optimizer over the generator's parameters.
    """
    g_loss = compute_G_loss(d_fake)
    # Without zero_grad(), backward() accumulates into .grad and every step
    # would use the running sum of all previous gradients.
    optim_G.zero_grad()
    g_loss.backward()
    optim_G.step()

def backward_D(d_real, d_fake, optim_D):
    """Compute the discriminator loss and take one optimizer step.

    Args:
        d_real: discriminator output on real samples.
        d_fake: discriminator output on (detached) generated samples.
        optim_D: optimizer over the discriminator's parameters.
    """
    d_loss = compute_D_loss(d_real, d_fake)
    # Clear stale gradients, including those the generator's backward pass
    # wrote into D's parameters, before accumulating this step's gradients.
    optim_D.zero_grad()
    d_loss.backward()
    optim_D.step()

def compute_G_loss(d_fake):
    """Return the WGAN generator loss, ``-mean(d_fake)``.

    The generator wants the critic to score its samples as high as possible,
    so it minimizes the negated mean critic output. This is the same value
    as ``WGANLoss()(d_fake, 1)``.
    """
    return -torch.mean(d_fake)

def compute_D_loss(d_real, d_fake):
    """Return the WGAN critic loss, ``mean(d_fake) - mean(d_real)``.

    Minimizing this pushes real scores up and fake scores down. It equals
    ``WGANLoss()(d_real, 1) + WGANLoss()(d_fake, 0)``.

    NOTE(review): WGAN assumes an (approximately) 1-Lipschitz critic; nothing
    visible here enforces that (e.g. weight clipping or a gradient penalty).
    Confirm it is handled elsewhere.
    """
    return torch.mean(d_fake) - torch.mean(d_real)

'''
Define here latent_codes, Imgs, G (Generator) and D (Discriminator)
'''

for latent_code,Img_real in zip(latent_codes,Imgs):

# ===update D===
d_real,d_fake = forward_D(latent_code,Img_real)
backward_D(d_real,d_fake,optim_D)

# ===update G===