Hello,
I’m new with pytorch (and also with GAN), and I need to compute the loss functions for both the discriminator and the generator.
The code is standard:
import torch.nn as nn
import torch.nn.functional as F
# Choose a value for the prior dimension
PRIOR_N = 25
# Define the generator
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(PRIOR_N, 2)
self.fc2 = nn.Linear(2, 2)
def __call__(self, z):
h = F.relu(self.fc1(z))
return self.fc2(h)
def generate(self, batchlen):
z = torch.normal(torch.zeros(batchlen, PRIOR_N), 1.0)
return self.__call__(z)
# Define the discriminator
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(2, 2)
self.fc2 = nn.Linear(2, 1)
def __call__(self, x):
h = F.relu(self.fc1(x))
return self.fc2(h)
# Number of times to train the discriminator between two generator steps
TRAIN_RATIO = 1
# Total number of training iterations for the generator
N_ITER = 20001
# Batch size to use
BATCHLEN = 128
generator = Generator()
optim_gen = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5,0.9))
discriminator = Discriminator()
optim_disc = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5,0.9))
for i in range(N_ITER):
# train the discriminator
for _ in range(TRAIN_RATIO):
discriminator.zero_grad()
# zero_grad() clear the gradient of all optimized torch tensor
real_batch = generate_batch(BATCHLEN)
# Create the real points
fake_batch = generator.generate(BATCHLEN)
# Create the fake points
# Compute here the discriminator loss, using functions like torch.sum, torch.exp,
# torch.log, torch.softplus, using real_batch and fake_batch
disc_loss = 0 # FILL HERE
disc_loss.backward()
optim_disc.step()
# train the generator
generator.zero_grad()
fake_batch = generator.generate(BATCHLEN)
# Compute here the generator loss, using fake_batch
gen_loss = 0 # FILL HERE
gen_loss.backward()
optim_gen.step()
if i%100 == 0:
print('step {}: discriminator: {:.3e}, generator: {:.3e}'.format(i, float(disc_loss), float(gen_loss)))
# plot the result
real_batch = generate_batch(1024)
fake_batch = generator.generate(1024).detach()
plt.scatter(real_batch[:,0], real_batch[:,1], s=2.0, label='real data')
plt.scatter(fake_batch[:,0], fake_batch[:,1], s=2.0, label='fake data')
plt.show()
So, I can’t use pre built functions, but have to use:
For the discriminator loss: functions like torch.sum, torch.exp, torch.log, torch.softplus, using real_batch and fake_batch
For the generator loss: fake_batch
The thing I’m pretty sure of, is that first, I have to create two new variables:
disc_real = discriminator(real_batch)
disc_fake = discriminator(fake_batch)
but then, I don’t really know what to do. I thought about beginning with something like:
disc_loss = torch.log(disc_real) + torch.log(1. - disc_fake)
Thanks for your help