Hello,
I am not quite sure how to categorize this question, so apologies in advance.
I am currently training a GAN to generate medical images. My code looks like this:
# CREATE GENERATED OUTPUT
fake = gen_model(data).to(device)
# TRAIN DISCRIMINATOR
discriminator.zero_grad()
disc_pred_fake = nn.Sigmoid()(discriminator(fake.detach()).to(device))
disc_pred_real = nn.Sigmoid()(discriminator(target_img.to(device)).to(device))
real_label = torch.ones(data.shape[0], 1).to(device)
fake_label = torch.zeros(data.shape[0], 1).to(device)
loss_disc_fake = nn.BCELoss()(disc_pred_fake, fake_label)
loss_disc_real = nn.BCELoss()(disc_pred_real, real_label)
loss_disc = (loss_disc_real + loss_disc_fake) / 2
loss_disc.backward()
opt_disc.step()
# TRAIN GENERATOR
gen_model.zero_grad()
disc_pred_fake = nn.Sigmoid()(discriminator(fake.detach()).to(device))
L1loss = nn.L1Loss()(fake, target_img.to(device))
loss_gen = nn.BCELoss()(disc_pred_fake, real_label)
full_loss = L1loss + loss_gen
full_loss.backward()
opt_gen.step()
This generates very good results, but what I find strange is that if I remove
disc_pred_fake = nn.Sigmoid()(discriminator(fake.detach()).to(device))
and instead use
disc_pred_fake = nn.Sigmoid()(discriminator(fake).to(device))
the results become complete nonsense, e.g. images that consist almost entirely of 1s (I normalize the images to [-1, 1]).
But when browsing through implementations of GANs I only see people using:
disc_pred_fake = nn.Sigmoid()(discriminator(fake).to(device))
instead of the version with .detach(). Am I doing something wrong?
If so, what does the detach() do in my case?
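For what it's worth, my current understanding of detach() in isolation (from a toy example, unrelated to my training code) is that it returns a tensor that is cut off from the autograd graph:

```python
import torch

# Toy tensors, not my GAN code: check what detach() does to the graph.
x = torch.tensor([2.0], requires_grad=True)
y = (x * 3).sum()
y.backward()
grad_without_detach = x.grad.clone()  # gradient flows back to x: tensor([3.])

x.grad = None
z = (x * 3).detach()    # detach() returns a tensor cut off from the graph
print(z.requires_grad)  # False -> calling backward() through z is impossible
```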
Is my implementation with .detach() incorrect, i.e. am I doing something strange to the gradient flow that happens to improve the results by coincidence?
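To see where the gradients actually go, I also put together a toy check (hypothetical tiny linear models standing in for my real networks):

```python
import torch
import torch.nn as nn

# Toy stand-ins for generator and discriminator, not my real models.
gen = nn.Linear(4, 4)
disc = nn.Linear(4, 1)

data = torch.randn(8, 4)
fake = gen(data)

# Adversarial loss computed on the detached fake, as in my generator step.
pred = torch.sigmoid(disc(fake.detach()))
loss = nn.BCELoss()(pred, torch.ones(8, 1))
loss.backward()

print(gen.weight.grad)   # None -> no adversarial gradient reaches the generator
print(disc.weight.grad is not None)  # True -> the discriminator does get one
```

This seems to suggest that with .detach() the adversarial loss never reaches the generator (only the L1 loss would train it), but I am not sure whether that actually explains why my results are better this way.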
I would be grateful for any kind of help, since I am really stuck on this.
Best regards