To learn about GANs I adapted the code for a WGAN and played around with it a bit (from Machine-Learning-Collection/ML/Pytorch/GANs/3. WGAN at master · aladdinpersson/Machine-Learning-Collection · GitHub).
I noticed that if I add a Tanh function to the discriminator model, training becomes unstable and generally much slower. Why is that?
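To make the change concrete, here is a minimal sketch of what I mean by "adding Tanh": a toy critic (this `Critic` architecture is just a stand-in, not the exact model from the repo) whose only difference from a plain WGAN critic is the final `nn.Tanh()`, which bounds every score to (-1, 1):

```python
import torch
import torch.nn as nn

class Critic(nn.Module):
    # Hypothetical minimal critic; only the final nn.Tanh() differs
    # from a standard unbounded WGAN critic.
    def __init__(self, channels_img=1, features_d=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels_img, features_d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d, 1, 4, 2, 1),
            nn.AdaptiveAvgPool2d(1),  # collapse feature map to one score per image
            nn.Tanh(),                # <- the extra transfer function
        )

    def forward(self, x):
        return self.net(x)

critic = Critic()
scores = critic(torch.randn(4, 1, 16, 16)).reshape(-1)
print(scores)  # all scores now lie strictly inside (-1, 1)
```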
Does anyone have an idea how I can make the critic output of a WGAN converge toward certain labels during training (e.g. -1 and +1, or 0 and 1) without adding a transfer function to the model itself?
This is the current training loop:
```python
# Train Critic: max E[critic(real)] - E[critic(fake)]
for _ in range(CRITIC_ITERATIONS):
    noise = torch.randn(cur_batch_size, RAND_DIM, 1, 1).to(device)
    fake = gen(noise)
    critic_real = critic(real).reshape(-1)
    critic_fake = critic(fake).reshape(-1)
    mean_real = torch.mean(critic_real)
    mean_fake = torch.mean(critic_fake)
    loss_critic = mean_fake - mean_real
    critic.zero_grad()
    loss_critic.backward(retain_graph=True)
    opt_crit.step()

    # clip critic weights between -0.01, 0.01
    for p in critic.parameters():
        p.data.clamp_(-0.01, 0.01)

# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
loss_gen = -torch.mean(critic(fake).reshape(-1))
gen.zero_grad()
loss_gen.backward()
opt_gen.step()
```
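To clarify the kind of thing I'm after, here is a hypothetical sketch (not in my current code) of what I mean by steering the outputs without touching the model: keep the plain Wasserstein term, and add an auxiliary MSE penalty that pulls `critic(real)` toward +1 and `critic(fake)` toward -1. The `LABEL_WEIGHT` hyperparameter is an assumption on my part:

```python
import torch

# Stand-ins for critic outputs on a batch of real and fake images.
critic_real = torch.randn(8, requires_grad=True)
critic_fake = torch.randn(8, requires_grad=True)

LABEL_WEIGHT = 0.1  # assumed weight balancing the two loss terms

# Usual WGAN critic objective (to be minimized).
wasserstein = torch.mean(critic_fake) - torch.mean(critic_real)

# Extra penalty pulling real scores toward +1 and fake scores toward -1.
label_penalty = torch.mean((critic_real - 1.0) ** 2) \
              + torch.mean((critic_fake + 1.0) ** 2)

loss_critic = wasserstein + LABEL_WEIGHT * label_penalty
loss_critic.backward()  # gradients flow through both terms
```

Would something like this be a sensible alternative to a Tanh output layer, or does it break the Wasserstein interpretation of the critic?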