GAN only learns properly when using detach()

Hello,

I am not quite sure how to categorize this question, so apologies in advance if this is the wrong place.

I am currently training a GAN to generate medical images. My code looks like this:

            # CREATE GENERATED OUTPUT
            fake = gen_model(data).to(device)

            # TRAIN DISCRIMINATOR
            discriminator.zero_grad()
            disc_pred_fake = nn.Sigmoid()(discriminator(fake.detach()).to(device))
            disc_pred_real = nn.Sigmoid()(discriminator(target_img.to(device)).to(device))
            real_label = torch.ones(data.shape[0], 1).to(device)
            fake_label = torch.zeros(data.shape[0], 1).to(device)
            loss_disc_fake = nn.BCELoss()(disc_pred_fake, fake_label)
            loss_disc_real = nn.BCELoss()(disc_pred_real, real_label)
            loss_disc = (loss_disc_real + loss_disc_fake) / 2
            loss_disc.backward()
            opt_disc.step()

            # TRAIN GENERATOR

            gen_model.zero_grad()
            disc_pred_fake = nn.Sigmoid()(discriminator(fake.detach()).to(device))
            L1loss = nn.L1Loss()(fake, target_img)
            loss_gen = nn.BCELoss()(disc_pred_fake, real_label)

            full_loss = L1loss + loss_gen

            full_loss.backward()
            opt_gen.step()

This generates very good results, but what I find strange is that if I remove

disc_pred_fake = nn.Sigmoid()(discriminator(fake.detach()).to(device)) 

and instead use

disc_pred_fake = nn.Sigmoid()(discriminator(fake).to(device))

my results become complete nonsense, i.e. the generator returns images that are almost all 1s (I normalize between -1 and 1).
But when browsing through implementations of GANs I only see people using:

disc_pred_fake = nn.Sigmoid()(discriminator(fake).to(device))

instead of the version with detach(). Am I doing it wrong?
If so, what does the detach() do in my case?
Am I implementing the GAN incorrectly when using the version with .detach(), i.e. doing something strange with the gradients that happens to improve the results by coincidence?

I would be really grateful for any kind of help, since I am stuck on this :frowning:

Best regards

Yes. In your generator update you pass fake.detach() to the discriminator, so loss_gen is cut off from the generator's computation graph and contributes no gradients to its parameters. In your code the generator is then effectively trained by the L1 term alone, since L1loss is computed on the non-detached fake. You can double check this by calling loss_gen.backward() on its own and printing the .grad attributes of the generator's parameters, which will show None for all of them.
Also, parameters that never received gradients won't be updated by the optimizer unless you are using e.g. weight decay.
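
Here is a minimal, self-contained sketch of the effect (the toy nn.Linear modules and the names gen and disc are just placeholders, not your actual models):

    import torch
    import torch.nn as nn

    gen = nn.Linear(2, 2)   # stand-in for the generator
    disc = nn.Linear(2, 1)  # stand-in for the discriminator

    fake = gen(torch.randn(1, 2))

    # Backward through the detached output: no gradients reach the generator
    disc(fake.detach()).sum().backward()
    print(gen.weight.grad)   # None -> the generator would not be updated
    print(disc.weight.grad)  # a tensor -> the discriminator still gets gradients

    # Backward through the attached output: gradients flow into the generator
    disc(fake).sum().backward()
    print(gen.weight.grad)   # now a tensor -> the generator can learn

This is exactly why the usual pattern is fake.detach() in the discriminator step (so the discriminator update doesn't backpropagate into the generator) but the non-detached fake in the generator step.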

Thanks a lot for your answer!

I took another look at the example implementation from this site: DCGAN Tutorial — PyTorch Tutorials 2.0.1+cu117 documentation
and further down in the code there is this snippet:

with torch.no_grad():
    fake = netG(fixed_noise).detach().cpu()

Why is detach() called here? My understanding now is: .detach() removes a tensor from the computation graph, but torch.no_grad() tells PyTorch not to calculate gradients anyway.
I am asking since in my evaluation I'm not using detach(), I just use the with torch.no_grad(): context.

Sorry if that question is rudimentary, but I really became unsure with the pitfalls .detach() has.

The posted code snippet is used for testing only as mentioned in the comment:

# Check how the generator is doing by saving G's output on fixed_noise

and detaching the output won’t be needed since the no_grad context already avoids creating a computation graph.
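
You can verify this with a toy stand-in for netG (the nn.Linear here is just a placeholder):

    import torch
    import torch.nn as nn

    netG = nn.Linear(4, 4)  # stand-in for the real generator
    fixed_noise = torch.randn(1, 4)

    with torch.no_grad():
        fake = netG(fixed_noise)

    # No graph was recorded inside no_grad, so there is nothing to detach from
    print(fake.requires_grad)  # False
    print(fake.grad_fn)        # None

    # Outside no_grad, detach() is what cuts the output out of the graph
    fake = netG(fixed_noise)
    print(fake.requires_grad)           # True
    print(fake.detach().requires_grad)  # False

So using with torch.no_grad(): alone in your evaluation is perfectly fine; the additional .detach() in the tutorial is simply redundant there.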