My simple GAN doesn't learn

Hey,
I tried to take much inspiration from the following blog post: https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f.
Wrote the following code to create a GAN the learns to produce data that looks like it came from a normal distibution with mean 4 and std 1.25. It doesn’t learn and would appreciate some help and advises. Also, I don’t understand why I have to write the retain_graph=True…?

import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from matplotlib import pyplot as plt

def get_distribution_sampler(mu, sigma):
  return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))

def get_generator_input_sampler():
  return lambda m, n: torch.rand(m, n)

class Generator(nn.Module):
  def __init__(self, input_size, hidden_size, output_size):
    super(Generator, self).__init__()
    self.main = nn.Sequential(
              nn.Linear(input_size, hidden_size),
              nn.ReLU(),
              nn.Linear(hidden_size, hidden_size),
              nn.ReLU(),
              nn.Linear(hidden_size, output_size)              )
  
  def forward(self, x):
    return self.main(x)

class Dicriminator(nn.Module):
  def __init__(self, input_size, hidden_size, output_size):
    super(Dicriminator, self).__init__()
    self.main = nn.Sequential(
              nn.Linear(input_size, hidden_size),
              nn.ReLU(),
              nn.Linear(hidden_size, hidden_size),
              nn.ReLU(),
              nn.Linear(hidden_size, output_size),
              nn.Sigmoid()
              )
  
  def forward(self, x):
    return self.main(x)

def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

def train():
  g_input_size = 1
  g_hidden_size = 10
  g_output_size = 1
  
  d_input_size = 500
  d_hidden_size = 10
  d_output_size = 1
  
  minibatch_size = d_input_size

  d_lr = g_lr = 1e-3
  
  num_epochs = 5000
  print_interval = 100
  d_steps = g_steps = 20

  d_sampler = get_distribution_sampler(4, 1.25)
  g_sampler = get_generator_input_sampler()

  G = Generator(g_input_size, g_hidden_size, g_output_size)
  D = Dicriminator(d_input_size, d_hidden_size, d_output_size)

  criterion = nn.BCELoss()

  d_optimizer = optim.Adam(D.parameters(), d_lr)
  g_optimizer = optim.Adam(G.parameters(), d_lr)

  for epoch in range(num_epochs):
    for _ in range(d_steps):
      D.zero_grad()

      # train D on real
      d_real_data = d_sampler(minibatch_size)
      d_real_decision = D(d_real_data)
      d_real_error = criterion(d_real_decision, torch.ones(1))

      d_real_error.backward(retain_graph=True)

      # train D on fake
      gen_input = g_sampler(minibatch_size, g_input_size)
      d_fake_data = G(gen_input)
      d_fake_decision = D(d_fake_data.detach().view(-1))
      d_fake_error = criterion(d_fake_decision, torch.zeros(1))

      d_fake_error.backward(retain_graph=True)

      # update gradients for train
      d_optimizer.step()

      dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]
    
    for _ in range(g_steps):
      G.zero_grad()

      gen_input = g_sampler(minibatch_size, g_input_size)
      g_fake_data = G(gen_input)
      dg_fake_decision = D(g_fake_data.view(-1))
      g_error = criterion(d_fake_decision, torch.ones(1))
      
      g_error.backward(retain_graph=True)
      g_optimizer.step()
      ge = extract(g_error)[0]

    if epoch % print_interval == 0:
      print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s),  Fake Dist (%s) " %
                  (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))
      print("Plotting the generated distribution for epoch " + str(epoch))
      values = extract(g_fake_data)
      print(" Values: %s" % (str(values)))
      plt.hist(values, bins=50)
      plt.xlabel('Value')
      plt.ylabel('Count')
      plt.title('Histogram of Generated Distribution')
      plt.grid(True)
      plt.show()

train()

you can never output N(4, 125) with sigmoid restricting G to output (-1,1)

1 Like

You are right, silly me. Took off the Sigmoid from the Generator, but it stil doesn’t work…