Conditional GAN, Sampling from a Gaussian, How to Correctly Use Variables

Hi there,

I am trying to make a conditional GAN that can sample from a Gaussian with a specific mean. For example, say the two mean labels are 0 and 10. For a mean label of 0, I would like the generator to eventually return values between -3 and 3, and for a mean label of 10, values between 7 and 13 or so (i.e. roughly the mean ± 3 for a unit-variance Gaussian). This works for a non-conditional GAN when all the values come from a single mean, i.e. it can learn to sample from a Gaussian around 0.

I generate “real” samples where half the values come from a Gaussian with mean 0 and half come from a Gaussian with mean 10. I supply that same “mean list” to the generator, so the generator gets a uniform noise value plus a mean as input and outputs a single value, which should ultimately look like a draw from a Gaussian with that mean.
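For context, my sampling helpers look roughly like this. Treat it as a simplified sketch rather than my exact code: the names sample_real and sample_noise and the return shapes match what I use below, but the internals (np.random.choice over the labels, scale=1.0) are stand-ins.

def sample_real(bs, shape, conditional_means):
    # shape is unused in this simplified 1-D sketch
    # pick a mean label per sample (in my real code, exactly half 0s and half 10s),
    # then draw one value per sample from a unit-variance Gaussian at that mean
    means = np.random.choice(conditional_means, size=(bs, 1))
    values = np.random.normal(loc=means, scale=1.0)
    real = Variable(torch.FloatTensor(values))
    means_v = Variable(torch.FloatTensor(means))
    # the discriminator sees each value and its mean label side by side
    both_real = Variable(torch.FloatTensor(np.column_stack((values, means))))
    return real, means_v, both_real

def sample_noise(bs, n_latent, conditional_means):
    # uniform noise with a mean label appended per sample
    noise = np.random.uniform(-1, 1, size=(bs, n_latent))
    means = np.random.choice(conditional_means, size=(bs, 1))
    return Variable(torch.FloatTensor(np.column_stack((noise, means))))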

However, to do this, I concatenate the generator output (one value) with the mean and then feed that to the discriminator as a new Variable. Could this be redirecting the graph and stopping the gradients from updating? In the code below:

- objective is just the Wasserstein metric (sketched just below).
- fake is the single value generated by the generator, which takes a noise value and a mean.
- fake_disc_input is the generated value concatenated with its mean.
- extraD just updates the discriminator more often than the generator.
- lamD is a penalty coefficient, currently set to 10.

When I run this, the gradD values are essentially 0 and nothing updates. Any ideas? Am I feeding the conditions in the wrong way? (I have also put a torch.cat variant after the code, in case that is the right way to keep everything in the graph.) Thanks a bunch!
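To be concrete, by the Wasserstein metric I mean the usual WGAN critic difference; a minimal sketch of what my objective does (my actual function may differ in details):

def objective(d_real, d_fake):
    # discriminator loss: push critic scores on fakes down and on reals up;
    # the generator then minimizes the negation of this
    return d_fake.mean() - d_real.mean()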

import numpy as np
import torch
from torch.autograd import Variable, grad

for iteration in iter_range:

    # several discriminator updates per generator update
    for extra in range(extraD):
        real, means, both_real = sample_real(bs, shape, conditional_means)

        fake = netG(sample_noise(bs, n_latent, conditional_means))
        # build the discriminator input as a brand-new Variable from numpy arrays
        fake_disc_input = Variable(torch.FloatTensor(np.column_stack((fake.data.numpy(), means.data.numpy()))), requires_grad=True)

        optD.zero_grad()
        lossD = objective(netD(both_real), netD(fake_disc_input))
        # gradient penalty: per-sample gradient of the loss w.r.t. the fake inputs
        gradD = grad(lossD * bs, fake_disc_input, create_graph=True)[0]
        # penalize gradient norms above 1 (norms below 1 are clamped to 1)
        reguD = gradD.norm(2, 1).clamp(1).mean()
        (lossD + lamD * reguD).backward()
        optD.step()

    # single generator update
    real, means, both_real = sample_real(bs, shape, conditional_means)
    fake = netG(sample_noise(bs, n_latent, conditional_means)).data.numpy()
    fake_disc_input = Variable(torch.FloatTensor(np.column_stack((fake, means.data.numpy()))), requires_grad=True)

    optG.zero_grad()
    lossG = -objective(netD(both_real), netD(fake_disc_input))
    lossG.backward()
    optG.step()
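For reference, here is the variant I considered that keeps the concatenation inside the autograd graph, using torch.cat on the live Variables instead of round-tripping through numpy. This is a sketch of what I mean, not tested code:

fake = netG(sample_noise(bs, n_latent, conditional_means))
# concatenate along dim 1 without leaving the graph; gradients can
# then flow from the discriminator back into netG through fake
fake_disc_input = torch.cat((fake, means), 1)

As far as I understand, torch.autograd.grad can also take this non-leaf Variable as its inputs argument for the penalty term. Is this the right way to feed the condition, or does creating a fresh Variable from .data.numpy() (as above) cut the generator out of the graph entirely?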