To learn about GANs I adapted the code for a WGAN and played around with it a bit (from Machine-Learning-Collection/ML/Pytorch/GANs/3. WGAN at master · aladdinpersson/Machine-Learning-Collection · GitHub).
I noticed that if I add a Tanh function to the discriminator model, training becomes unstable and generally much slower. Why is that?
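To make the change concrete, here is a minimal sketch of what I mean by "adding Tanh": a toy critic (this `Critic` architecture is just a stand-in, not the exact model from the repo) whose only difference from a plain WGAN critic is the final `nn.Tanh()`, which bounds every score to (-1, 1):

```python
import torch
import torch.nn as nn

class Critic(nn.Module):
    # Hypothetical minimal critic; only the final nn.Tanh() differs
    # from a standard unbounded WGAN critic.
    def __init__(self, channels_img=1, features_d=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels_img, features_d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d, 1, 4, 2, 1),
            nn.AdaptiveAvgPool2d(1),  # collapse feature map to one score per image
            nn.Tanh(),                # <- the extra transfer function
        )

    def forward(self, x):
        return self.net(x)

critic = Critic()
scores = critic(torch.randn(4, 1, 16, 16)).reshape(-1)
print(scores)  # all scores now lie strictly inside (-1, 1)
```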
Does anyone have an idea how I can make the critic output of a WGAN converge toward certain labels during training (e.g. -1 and +1, or 0 and 1) without adding a transfer function to the model itself?
This is the current training loop:
```python
# Train Critic: max E[critic(real)] - E[critic(fake)]
for _ in range(CRITIC_ITERATIONS):
    noise = torch.randn(cur_batch_size, RAND_DIM, 1, 1).to(device)
    fake = gen(noise)
    critic_real = critic(real).reshape(-1)
    critic_fake = critic(fake).reshape(-1)
    mean_real = torch.mean(critic_real)
    mean_fake = torch.mean(critic_fake)
    loss_critic = mean_fake - mean_real
    critic.zero_grad()
    loss_critic.backward(retain_graph=True)
    opt_crit.step()

    # clip critic weights between -0.01, 0.01
    for p in critic.parameters():
        p.data.clamp_(-0.01, 0.01)

# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
loss_gen = -torch.mean(critic(fake).reshape(-1))
gen.zero_grad()
loss_gen.backward()
opt_gen.step()
```
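To clarify the kind of thing I'm after, here is a hypothetical sketch (not in my current code) of what I mean by steering the outputs without touching the model: keep the plain Wasserstein term, and add an auxiliary MSE penalty that pulls `critic(real)` toward +1 and `critic(fake)` toward -1. The `LABEL_WEIGHT` hyperparameter is an assumption on my part:

```python
import torch

# Stand-ins for critic outputs on a batch of real and fake images.
critic_real = torch.randn(8, requires_grad=True)
critic_fake = torch.randn(8, requires_grad=True)

LABEL_WEIGHT = 0.1  # assumed weight balancing the two loss terms

# Usual WGAN critic objective (to be minimized).
wasserstein = torch.mean(critic_fake) - torch.mean(critic_real)

# Extra penalty pulling real scores toward +1 and fake scores toward -1.
label_penalty = torch.mean((critic_real - 1.0) ** 2) \
              + torch.mean((critic_fake + 1.0) ** 2)

loss_critic = wasserstein + LABEL_WEIGHT * label_penalty
loss_critic.backward()  # gradients flow through both terms
```

Would something like this be a sensible alternative to a Tanh output layer, or does it break the Wasserstein interpretation of the critic?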