Why does the GAN Discriminator get the fake image twice?

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(DEVICE)
        batch_size = real.shape[0]

        # Discriminator
        noise = torch.randn(batch_size, z_dim).to(DEVICE)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()
        
        # Generator
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

In the Generator part you can see that I use:

        # Generator
        output = disc(fake).view(-1)

which runs fake through the Discriminator a second time, even though disc_fake was already computed in the Discriminator part above.
Why does the following not work?

        # Generator
        # output = disc(fake).view(-1)   # reuse disc_fake from above instead
        lossG = criterion(disc_fake, torch.ones_like(disc_fake))

It fails with:

RuntimeError: 
one of the variables needed for gradient computation has been modified by 
an in-place operation: [torch.cuda.FloatTensor [128, 1]], which is output 0 of 
AsStridedBackward0, is at version 18754; expected version 18753 instead. Hint: enable 
anomaly detection to find the operation that failed to compute its gradient, with 
torch.autograd.set_detect_anomaly(True).

I don’t understand why this redundant computation is needed and why the result can’t simply be reused.

I tried doing the optimizer steps only after all gradients were calculated, which resolved the problem:

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(DEVICE)
        batch_size = real.shape[0]

        # Discriminator:
        noise = torch.randn(batch_size, z_dim).to(DEVICE)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)

        # Generator
        # output = disc(fake).view(-1)
        # now I can reuse disc_fake from above
        lossG = criterion(disc_fake, torch.ones_like(disc_fake))
        gen.zero_grad()
        lossG.backward()
        opt_disc.step()
        opt_gen.step()

Maybe there is a misunderstanding on my side of how gradients work in PyTorch. Could someone please share some insights?

The error seems to suggest there is an in-place operation in the model.

https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd

For example, you should not do:

def forward(self, x, y):
    ...
    x+=self.final(y)
    return x
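The out-of-place version, which keeps autograd's saved tensors valid, would look like this (hypothetical module, only meant to illustrate the difference):

import torch.nn as nn

class Head(nn.Module):  # hypothetical module, just to contrast in-place vs. out-of-place
    def __init__(self, dim):
        super().__init__()
        self.final = nn.Linear(dim, dim)

    def forward(self, x, y):
        # out-of-place add: builds a new tensor instead of mutating x,
        # so the version of x that autograd saved stays valid
        x = x + self.final(y)
        return x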

I’m afraid that this is not the reason.

It is definitely related to the computational graph and the way gradients are updated.

Just for completeness here are the models:

class Discriminator(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)

Try changing this line (add .detach()):

        ...
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).view(-1)  # < ---- change here
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        ...

Without a copyable working sample, this is a bit of guesswork. But that seems to be one of the issues, at least.
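To make the effect of .detach() concrete, here is a quick check using the Generator/Discriminator classes you posted (the dimensions are just placeholders):

import torch

gen = Generator(z_dim=64, img_dim=784)
disc = Discriminator(in_features=784)

noise = torch.randn(8, 64)
fake = gen(noise)

# With detach the graph is cut at `fake`, so the discriminator loss cannot reach the generator.
disc(fake.detach()).sum().backward()
print(all(p.grad is None for p in gen.parameters()))   # True: generator received no gradients

# Without detach the same backward also fills the generator's gradients.
disc(fake).sum().backward()
print(all(p.grad is None for p in gen.parameters()))   # False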

Detaching has essentially the same effect as:

        lossD.backward(retain_graph=True)

which I’ve included.

As mentioned above, I provided a remedy for my problem, but I don’t understand why I should pass the generated (fake) image to the Discriminator twice instead of reusing the output from the first pass.

My hypothesis is that after the first pass the Discriminator weights are already updated a little bit, so the Generator should not generate for the initial Discriminator but for the Discriminator after the first weight update.

I just don’t understand why we would do something like that.

Not exactly, since detaching the fake tensor will avoid computing gradients in the generator using the discriminator loss.
The training logic is:

  • train the discriminator with a real input and real target to learn how to detect real images
  • create a fake output using the generator
  • train the discriminator using the detached fake generator output with fake labels to learn how to detect fake images
  • the generator should not get any gradients at this stage since the target is wrong
  • use the fake generator output in the discriminator with real targets, calculate the gradients in the generator to learn how to create fake images which will “confuse” the discriminator

If you instead retain the graph, the earlier opt_disc.step() call will already have updated the discriminator parameters, so the forward activations kept from the previous pass are stale, and the next backward() call will fail.
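Here is a minimal, self-contained sketch of that failure mode (not your code; a toy linear layer stands in for the discriminator and a leaf tensor for the generator output):

import torch
import torch.nn as nn

disc = nn.Linear(3, 1)                           # toy stand-in for the discriminator
opt_disc = torch.optim.SGD(disc.parameters(), lr=0.1)

fake = torch.randn(4, 3, requires_grad=True)     # stand-in for gen(noise)
out = disc(fake)                                 # disc.weight is saved in the graph for backward

lossD = out.sum()
lossD.backward(retain_graph=True)                # first backward works
opt_disc.step()                                  # updates disc.weight in place -> version counter bumps

lossG = out.sum()                                # reuses the stale forward result
lossG.backward()                                 # RuntimeError: one of the variables needed for
                                                 # gradient computation has been modified by an
                                                 # in-place operation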

Almost everything is clear now. I changed the code according to your training logic
(Please feel free to critique if it’s not)

What still didn’t click is the question in the code comment below:

#Standard Train Loop
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(DEVICE)
        batch_size = real.shape[0]
        
        # Discriminator:
        ##############
        # Feed Real Image to discriminator and provide label = 1
        disc_output_real_img = discriminator(real).view(-1)
        disc_loss_real_img = criterion(disc_output_real_img, torch.ones_like(disc_output_real_img))
        
        # Generate noise, turn it into a fake image with the generator, and feed that to the discriminator with label = 0
        noise = torch.randn(batch_size, z_dim).to(DEVICE)
        fake_img = generator(noise)        
        disc_output_fake_img = discriminator(fake_img.detach()).view(-1)
        disc_loss_fake_img = criterion(disc_output_fake_img, torch.zeros_like(disc_output_fake_img))
        
        # Combine loss for real and fake image and update weights
        disc_loss = (disc_loss_real_img + disc_loss_fake_img) / 2
        discriminator.zero_grad()
        disc_loss.backward()  # no retain_graph needed since fake_img is detached above
        optimizer_disc.step()
        
        

        # Generator loss:
        ##############
        output = discriminator(fake_img).view(-1)  # ----> Isn't this output redundant,
                                                   #       since it was already calculated above?
                                                   #       See above: disc_output_fake_img = ...
                                                   #       Why pass fake_img to the discriminator again?
        gen_loss = criterion(output, torch.ones_like(output))
        generator.zero_grad()
        gen_loss.backward()
        optimizer_gen.step()

No, it’s not redundant, since you have already updated the discriminator via optimizer_disc.step(), so output will be different.
If you tried to reuse the same output to calculate the gradients in the generator, it would fail, since the parameters of the discriminator were updated while the forward activations stored during the previous forward pass (kept via retain_graph=True) are now stale.
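Putting it together, the whole update then looks like this (a condensed sketch of the loop above, same variable names, no retain_graph needed):

for real, _ in loader:
    real = real.view(-1, 784).to(DEVICE)
    noise = torch.randn(real.shape[0], z_dim).to(DEVICE)
    fake_img = generator(noise)

    # 1) Discriminator: real images with label 1, detached fakes with label 0
    disc_real = discriminator(real).view(-1)
    disc_fake = discriminator(fake_img.detach()).view(-1)
    disc_loss = (criterion(disc_real, torch.ones_like(disc_real)) +
                 criterion(disc_fake, torch.zeros_like(disc_fake))) / 2
    discriminator.zero_grad()
    disc_loss.backward()
    optimizer_disc.step()

    # 2) Generator: fresh forward pass through the *updated* discriminator, label 1
    output = discriminator(fake_img).view(-1)
    gen_loss = criterion(output, torch.ones_like(output))
    generator.zero_grad()
    gen_loss.backward()
    optimizer_gen.step()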
