Joint AE/GAN training


I’m trying to jointly train a convolutional network as an autoencoder (AE) and as a GAN, but I’m not sure that I have the training routine set up correctly. I would greatly appreciate any help.

for epoch in range(n_iter):
    for i, (batch, _) in enumerate(dataloader):

        current_batch_size = batch.size(0)

        #Train as AE

        optim_decoder.zero_grad()  # clear gradients left over from the previous step
        input = Variable(batch).cuda()

        encoded = _Encoder(input)
        encoded = encoded.view(input.size(0), -1, 1, 1)
        reconstructed = _Decoder(encoded)

        reconstruction_loss = criterion(reconstructed, input)
        reconstruction_loss.backward()  # compute gradients before stepping

        optim_decoder.step()  # here it's SGD

        #Train as GAN

        #Train Discriminator on real

        real_samples = input.clone()
        inference_real = _Discriminator(real_samples)
        labels = torch.FloatTensor(current_batch_size).fill_(real_label)
        labels = Variable(labels).cuda()
        real_loss = criterion(inference_real, labels)

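        #Generate fake samples
        # uniform noise over a wide range; this line was truncated, so its exact form is assumed
        noise = Variable(torch.mul(torch.rand(current_batch_size, z_d, 1, 1), 10)).cuda()
        fake_samples = _Decoder(noise)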

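        #Train Discriminator on fake
        inference_fake = _Discriminator(fake_samples.detach())
        # fake samples need the fake target, not the real one (0 is assumed as the fake label)
        fake_labels = Variable(torch.zeros(current_batch_size)).cuda()
        fake_loss = criterion(inference_fake, fake_labels)
        discriminator_total_loss = real_loss + fake_loss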

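        #Update Decoder/Generator with how well it fooled Discriminator

        optim_decoderGAN.zero_grad()  # clear gradients left over from the previous step
        inference_fake_Decoder = _Discriminator(fake_samples)
        fake_samples_loss = criterion(inference_fake_Decoder, labels)
        fake_samples_loss.backward()  # compute gradients before stepping
        optim_decoderGAN.step()  # here it's Adam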
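
For comparison, here is a minimal sketch of how one joint AE/GAN step is often arranged, with each phase doing its own zero_grad, backward and step. It is not the routine from this post: the tiny networks, the names `encoder`, `decoder`, `discriminator`, `optim_ae`, `optim_d`, `optim_g` and `joint_step`, the 1x8x8 input size, the learning rates, the 0/1 labels and the noise range are all assumptions made for illustration, and it runs on CPU with current PyTorch tensors rather than `Variable`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_d = 8  # latent size (hypothetical value)

# Tiny stand-in networks for 1x8x8 images; real models would be deeper.
encoder = nn.Sequential(nn.Conv2d(1, z_d, kernel_size=8))  # -> (B, z_d, 1, 1)
decoder = nn.Sequential(nn.ConvTranspose2d(z_d, 1, kernel_size=8), nn.Sigmoid())  # -> (B, 1, 8, 8)
discriminator = nn.Sequential(nn.Conv2d(1, 1, kernel_size=8), nn.Sigmoid())  # -> (B, 1, 1, 1)

recon_criterion = nn.MSELoss()
gan_criterion = nn.BCELoss()

# SGD for the autoencoder path, Adam for the GAN path (mirrors the post's choice).
optim_ae = torch.optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=0.01)
optim_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optim_g = torch.optim.Adam(decoder.parameters(), lr=2e-4, betas=(0.5, 0.999))


def joint_step(batch):
    """Run one AE update, one discriminator update and one generator update."""
    b = batch.size(0)
    real_labels = torch.ones(b)   # assumed real label: 1
    fake_labels = torch.zeros(b)  # assumed fake label: 0

    # 1) Autoencoder update: encoder and decoder minimise reconstruction error.
    optim_ae.zero_grad()
    reconstructed = decoder(encoder(batch))
    reconstruction_loss = recon_criterion(reconstructed, batch)
    reconstruction_loss.backward()
    optim_ae.step()

    # 2) Discriminator update: real samples -> real label, detached fakes -> fake label.
    optim_d.zero_grad()
    noise = torch.empty(b, z_d, 1, 1).uniform_(-1.0, 1.0)  # hypothetical range
    fake_samples = decoder(noise)
    real_loss = gan_criterion(discriminator(batch).view(-1), real_labels)
    fake_loss = gan_criterion(discriminator(fake_samples.detach()).view(-1), fake_labels)
    d_loss = real_loss + fake_loss
    d_loss.backward()
    optim_d.step()

    # 3) Generator update: the decoder is rewarded when fakes are scored as real.
    optim_g.zero_grad()
    g_loss = gan_criterion(discriminator(fake_samples).view(-1), real_labels)
    g_loss.backward()
    optim_g.step()

    return reconstruction_loss.item(), d_loss.item(), g_loss.item()
```

Calling `joint_step(torch.rand(4, 1, 8, 8))` returns the three scalar losses for that batch. The generator phase also leaves gradients on the discriminator's parameters, which is harmless here because `optim_d.zero_grad()` clears them at the start of the next discriminator update.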

My goal is to have the Encoder map the real samples to specific locations in the latent space, and then have the Decoder/Generator produce potential candidates between and around the points that are mapped to real samples.

Also, I know that the “stabilizing GANs” post says to sample from a normal rather than a uniform distribution, and that my range is quite large. I started doing that because evaluating the trained DeeperDCGANs that I’ve been using seems to show that the usual Z.normal_(0, 1) covers too small a range, and the points are continually overwritten at each iteration. I also think that the feature space of my dataset likely follows a uniform rather than a normal distribution.
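
To make the sampling question concrete, the sketch below draws latent codes from a standard normal and from a wide uniform range, then interpolates linearly between two codes, which is one simple way to generate candidates between mapped points. The sizes, the range `r` and the helper `interpolate` are assumptions for illustration; in practice the two endpoint codes would be encoder outputs for real samples rather than random draws.

```python
import torch

torch.manual_seed(0)
batch_size, z_d = 4, 8  # hypothetical sizes

# Standard normal latent noise, as in the usual DCGAN setup.
z_normal = torch.randn(batch_size, z_d, 1, 1)

# Uniform latent noise over a wider, explicit range [-r, r].
r = 10.0  # hypothetical range
z_uniform = torch.empty(batch_size, z_d, 1, 1).uniform_(-r, r)


def interpolate(z_a, z_b, steps):
    """Return `steps` latent codes evenly spaced on the line from z_a to z_b."""
    weights = torch.linspace(0.0, 1.0, steps).view(steps, 1, 1, 1)
    return (1.0 - weights) * z_a + weights * z_b


# Two stand-in codes; these would normally come from the encoder.
z_a, z_b = z_uniform[0:1], z_uniform[1:2]
path = interpolate(z_a, z_b, steps=5)  # shape (5, z_d, 1, 1)
```

Each row of `path` can be passed to the Decoder to inspect what it generates between the two endpoints.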