GAN Loss Function Question

Hi PyTorch Discussion, I've found a lot of help here on previous projects, so I wanted to say thanks first!

I am currently playing around with a generative model that should learn to create album covers from a dataset I found on Kaggle.

Here is my generator:

import torch
import torch.nn as nn

class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()  

    kernel_size = 4
    padding = 1
    stride = 2
    alpha = 0.2
    size = 4 * 4 * 1024
    
    # https://www.ritchievink.com/blog/2018/07/16/generative-adversarial-networks-in-pytorch-the-distribution-of-art/
    # Project a latent vector of length d_dim to length 4 * 4 * 1024, reshape it
    # into a 4 x 4 feature map with 1024 channels, and feed that into the network,
    # which upsamples it to a 128 x 128 image with 3 channels.

    self.input = nn.Linear(d_dim, size)  # d_dim: latent size, a global defined elsewhere in the notebook
    self.net = nn.Sequential(
      nn.BatchNorm2d(1024),
      nn.LeakyReLU(alpha),
      nn.ConvTranspose2d(1024, 512, kernel_size, stride, padding),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(alpha),
      nn.ConvTranspose2d(512, 512, kernel_size, stride, padding),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(alpha),
      nn.ConvTranspose2d(512, 256, kernel_size, stride, padding),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(alpha),
      nn.ConvTranspose2d(256, 128, kernel_size, stride, padding),
      nn.BatchNorm2d(128),
      nn.LeakyReLU(alpha),
      nn.ConvTranspose2d(128, 3, kernel_size, stride, padding),
      nn.Tanh()
    )
  
  def forward(self, z):
    x = self.input(z)
    # Tanh outputs values in [-1, 1]; rescale to [0, 255] before viewing
    # or saving the images
    return self.net(x.view(-1, 1024, 4, 4))
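Since the Tanh output lives in [-1, 1], both directions of the scaling can be written as two small helpers. This is a minimal sketch, assuming the real images arrive as uint8 tensors in [0, 255]; the names to_model_range and to_pixel_range are hypothetical and not from the notebook.

```python
import torch

def to_model_range(images_uint8: torch.Tensor) -> torch.Tensor:
    # uint8 pixels in [0, 255] -> float32 in [-1, 1], the same range as Tanh
    return images_uint8.float() / 127.5 - 1.0

def to_pixel_range(images: torch.Tensor) -> torch.Tensor:
    # generator output in [-1, 1] -> uint8 pixels in [0, 255];
    # clamp before the cast so out-of-range values cannot wrap around
    return ((images + 1.0) * 127.5).round().clamp(0, 255).to(torch.uint8)

pixels = torch.tensor([0, 128, 255], dtype=torch.uint8)
scaled = to_model_range(pixels)
restored = to_pixel_range(scaled)
print(restored.tolist())  # [0, 128, 255]
```

If the images are instead loaded with torchvision's transforms.ToTensor() (which yields floats in [0, 1]), adding transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps them to the same [-1, 1] range.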

And here is my discriminator:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        kernel_size = 4
        padding = 1
        stride = 2
        alpha = 0.2

        # Because the generator outputs values in [-1, 1] (Tanh), the real images
        # must be scaled to [-1, 1] as well.
        
        self.net = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size, stride, padding),
            nn.LeakyReLU(alpha),
            nn.Conv2d(128, 256, kernel_size, stride, padding),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(alpha),
            nn.Conv2d(256, 512, kernel_size, stride, padding),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(alpha),
            nn.Conv2d(512, 512, kernel_size, stride, padding),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(alpha),
            nn.Conv2d(512, 1024, kernel_size, stride, padding),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(alpha),
        )
        self.output = nn.Linear(4 * 4 * 1024, 1)
        
    def forward(self, x):
        x = self.net(x)
        # Each stride-2 conv halves the spatial size: 128 x 128 -> 4 x 4 with
        # 1024 channels
        x = torch.reshape(x, (-1, 4 * 4 * 1024))
        x = self.output(x)
        
        # if self.training:
        #     return x
        
        return torch.sigmoid(x)
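To double-check the comments about spatial sizes, the output-size formulas from the Conv2d and ConvTranspose2d documentation can be traced layer by layer with kernel_size=4, stride=2, padding=1. A plain-Python sketch (the helper names are hypothetical):

```python
def conv2d_out(size, kernel_size=4, stride=2, padding=1, dilation=1):
    # Output-size formula from the torch.nn.Conv2d docs (floor division)
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def conv_transpose2d_out(size, kernel_size=4, stride=2, padding=1,
                         dilation=1, output_padding=0):
    # Output-size formula from the torch.nn.ConvTranspose2d docs
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

def trace(start, layer_fn, n_layers):
    # Apply the size formula once per layer and record every intermediate size
    sizes = [start]
    for _ in range(n_layers):
        sizes.append(layer_fn(sizes[-1]))
    return sizes

print(trace(4, conv_transpose2d_out, 5))  # Generator:     [4, 8, 16, 32, 64, 128]
print(trace(128, conv2d_out, 5))          # Discriminator: [128, 64, 32, 16, 8, 4]
```

With these settings each ConvTranspose2d doubles the size and each Conv2d halves it, so the 4 * 4 * 1024 features match both the generator's input Linear and the discriminator's output Linear.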

I have a pretty typical training loop that just alternates between these two functions for a fixed number of epochs:

from torch.autograd import Variable

def D_train(x):
  D.zero_grad()

  # Creating inputs of real data
  realX, realY = x.to(device), torch.ones(x.size(0), 1).to(device)
  D_real_output = D(realX)
  D_real_loss = criterion(D_real_output, realY)
  
  # Creating inputs of generated data
  z = torch.randn(x.size(0),d_dim).to(device)
  fakeX, fakeY = G(z), torch.zeros(x.size(0), 1).to(device)
  D_fake_output = D(fakeX)
  D_fake_output = D_fake_output.view(x.size(0), 1)
  D_fake_loss = criterion(D_fake_output, fakeY)
  
  # Take the average
  D_loss = Variable((D_fake_loss + D_real_loss)/2, requires_grad=True)
  D_loss.backward()
  D_optimizer.step()     

  return D_loss.item(), x.size(0)

def G_train(size):
  G.zero_grad()

  z = torch.randn(size,d_dim).to(device)
  y = torch.ones(size, 1).to(device)

  G_output = G(z)
  D_output = D(G_output)
  D_output = D_output.view(size, 1)
  G_loss = Variable(criterion(D_output, y), requires_grad=True)

  G_loss.backward()
  G_optimizer.step()
 
  return G_loss.item()

My issue is that I have fed my GAN the full dataset (~80,000 images in batches of 64) for more than 15 epochs, and both the discriminator and generator losses stay constant. The generated samples I print also clearly show no learning. I've been rummaging through my code for a while now and I'm very lost, so I would really appreciate any insights into what might be wrong. Here is the link to the full code:

https://github.com/alexlin51/GANs-Album-Covers/blob/main/Album_Cover_Fun.ipynb

Thank you again!

You are detaching the loss tensors from the computation graph by wrapping them in the deprecated Variable here:

D_loss = Variable((D_fake_loss + D_real_loss)/2, requires_grad=True)

so you should remove the Variable usage (the same applies to G_loss in G_train).
(Note that rewrapping the loss in a new tensor would cause the same issue.)
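A small, self-contained reproduction of the detaching effect, assuming a single toy parameter w in place of a real model:

```python
import torch
from torch.autograd import Variable  # deprecated; imported only to reproduce the issue

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)  # stands in for a model parameter
x = torch.randn(3)

loss = (w * x).sum()
print(loss.grad_fn is not None)          # True: loss is attached to the graph through w

# Wrapping creates a new leaf tensor that shares the value but not the graph.
wrapped = Variable(loss, requires_grad=True)
print(wrapped.grad_fn, wrapped.is_leaf)  # None True
wrapped.backward()
grad_after_wrap = w.grad
print(grad_after_wrap)                   # None: nothing reached w

# Re-creating the loss as a new tensor has the same effect.
rewrapped = torch.tensor(loss.item(), requires_grad=True)
rewrapped.backward()
grad_after_rewrap = w.grad
print(grad_after_rewrap)                 # None

# Backpropagating the original loss does reach w; d(sum(w * x))/dw = x.
loss.backward()
print(torch.allclose(w.grad, x))         # True
```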

Also, if you are training the discriminator, you usually don’t want to calculate any gradients for the generator, so you would most likely want to detach() the generator output before feeding it to D.
The DCGAN tutorial shows a good example of GAN training.
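Putting both suggestions together, the two training functions could look roughly like this. It is a sketch rather than the notebook's code: the tiny G and D stand-ins, d_dim = 8, the random real_batch, and the Adam settings (lr=2e-4, betas=(0.5, 0.999), as used in the DCGAN tutorial) are assumptions chosen so the snippet runs on its own. The Generator/Discriminator from the question could be dropped in unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = torch.device("cpu")
d_dim = 8  # hypothetical latent size for this sketch

# Tiny hypothetical stand-ins for the Generator/Discriminator, for speed.
G = nn.Sequential(nn.Linear(d_dim, 16), nn.Tanh()).to(device)
D = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid()).to(device)
criterion = nn.BCELoss()  # assumes D ends in a sigmoid, as in the post
D_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
G_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

def D_train(x):
    D.zero_grad()
    realX = x.to(device)
    realY = torch.ones(x.size(0), 1, device=device)
    D_real_loss = criterion(D(realX), realY)

    z = torch.randn(x.size(0), d_dim, device=device)
    fakeX = G(z).detach()  # detach(): the D step computes no gradients for G
    fakeY = torch.zeros(x.size(0), 1, device=device)
    D_fake_loss = criterion(D(fakeX), fakeY)

    D_loss = (D_real_loss + D_fake_loss) / 2  # plain arithmetic keeps the graph
    D_loss.backward()
    D_optimizer.step()
    return D_loss.item(), x.size(0)

def G_train(size):
    G.zero_grad()
    z = torch.randn(size, d_dim, device=device)
    y = torch.ones(size, 1, device=device)
    G_loss = criterion(D(G(z)), y)  # no detach here: gradients flow through D into G
    G_loss.backward()
    G_optimizer.step()  # only G's parameters are updated
    return G_loss.item()

real_batch = torch.rand(64, 16) * 2 - 1  # hypothetical "real" data in [-1, 1]
d_before = [p.clone() for p in D.parameters()]
g_before = [p.clone() for p in G.parameters()]

d_loss, n = D_train(real_batch)
g_grads_after_D = [p.grad for p in G.parameters()]  # stays None thanks to detach()
g_loss = G_train(64)
print(d_loss, g_loss)
```

D's stale gradients left over from G_train are cleared by D.zero_grad() at the start of the next D_train call.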


I appreciate you!! Thanks!