Hey,
I tried to take much inspiration from the following blog post: https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f.
Wrote the following code to create a GAN the learns to produce data that looks like it came from a normal distibution with mean 4 and std 1.25. It doesn’t learn and would appreciate some help and advises. Also, I don’t understand why I have to write the retain_graph=True…?
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from matplotlib import pyplot as plt
def get_distribution_sampler(mu, sigma):
return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))
def get_generator_input_sampler():
return lambda m, n: torch.rand(m, n)
class Generator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_size) )
def forward(self, x):
return self.main(x)
class Dicriminator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Dicriminator, self).__init__()
self.main = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_size),
nn.Sigmoid()
)
def forward(self, x):
return self.main(x)
def extract(v):
return v.data.storage().tolist()
def stats(d):
return [np.mean(d), np.std(d)]
def train():
g_input_size = 1
g_hidden_size = 10
g_output_size = 1
d_input_size = 500
d_hidden_size = 10
d_output_size = 1
minibatch_size = d_input_size
d_lr = g_lr = 1e-3
num_epochs = 5000
print_interval = 100
d_steps = g_steps = 20
d_sampler = get_distribution_sampler(4, 1.25)
g_sampler = get_generator_input_sampler()
G = Generator(g_input_size, g_hidden_size, g_output_size)
D = Dicriminator(d_input_size, d_hidden_size, d_output_size)
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), d_lr)
g_optimizer = optim.Adam(G.parameters(), d_lr)
for epoch in range(num_epochs):
for _ in range(d_steps):
D.zero_grad()
# train D on real
d_real_data = d_sampler(minibatch_size)
d_real_decision = D(d_real_data)
d_real_error = criterion(d_real_decision, torch.ones(1))
d_real_error.backward(retain_graph=True)
# train D on fake
gen_input = g_sampler(minibatch_size, g_input_size)
d_fake_data = G(gen_input)
d_fake_decision = D(d_fake_data.detach().view(-1))
d_fake_error = criterion(d_fake_decision, torch.zeros(1))
d_fake_error.backward(retain_graph=True)
# update gradients for train
d_optimizer.step()
dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]
for _ in range(g_steps):
G.zero_grad()
gen_input = g_sampler(minibatch_size, g_input_size)
g_fake_data = G(gen_input)
dg_fake_decision = D(g_fake_data.view(-1))
g_error = criterion(d_fake_decision, torch.ones(1))
g_error.backward(retain_graph=True)
g_optimizer.step()
ge = extract(g_error)[0]
if epoch % print_interval == 0:
print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s), Fake Dist (%s) " %
(epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))
print("Plotting the generated distribution for epoch " + str(epoch))
values = extract(g_fake_data)
print(" Values: %s" % (str(values)))
plt.hist(values, bins=50)
plt.xlabel('Value')
plt.ylabel('Count')
plt.title('Histogram of Generated Distribution')
plt.grid(True)
plt.show()
train()