How does .detach() change the way the generator trains?

Below is my code for training a GAN:

def train_discriminator(optimizer, real_data, fake_data):
    # set optimizer gradients to zero to store fresh gradients
    optimizer.zero_grad()
    # adversarial discriminator loss on real samples
    prediction_real = disc(real_data)
    error_real = loss_bce(prediction_real, real_data_target(real_data.size(0)))
    # adversarial discriminator loss on generated samples
    prediction_fake = disc(fake_data)
    error_fake = loss_bce(prediction_fake, fake_data_target(fake_data.size(0)))
    total_error = (error_real + error_fake)/2

    total_error.backward() #backprop
    optimizer.step() #update weights
    return total_error, prediction_real, prediction_fake

def train_generator(optimizer, real_data, fake_data):
    # set optimizer gradients to zero to store fresh gradients
    optimizer.zero_grad()
    #content loss
    content_loss = loss_mse(real_data, fake_data)

    #adversarial generative loss: the generator wants disc to label fakes as real
    prediction_fake = disc(fake_data)
    adv_loss = loss_bce(prediction_fake, real_data_target(fake_data.size(0)))
    total_gen_loss = content_loss + 0.001*adv_loss
    total_gen_loss.backward() #backprop
    optimizer.step() #update weights
    return total_gen_loss

for epoch in range(EPOCHS):
    g_err_sum = 0.0
    d_err_sum = 0.0
    num_batches = len(training_loader)
    for real_batch in training_loader:
        real_batch = real_batch.cuda()
        real_data = real_batch 
        fake_data = gen(real_data[:,0:2]) # Line 1
        #train discriminator
        d_error, _ ,_ = train_discriminator(d_opt, real_data, fake_data.detach()) # Line 2
        #train generator
        g_error = train_generator(g_opt, real_data, fake_data)
        #accumulate plain floats so the running sums do not hold on to the graph
        d_err_sum += d_error.item()
        g_err_sum += g_error.item()
    if epoch%10==0:
        print("Epoch no : {} Discriminator error: {}   Generator error: {}".format(epoch,d_err_sum/num_batches,g_err_sum/num_batches))

In the main loop (Line 1), changing
gen(real_data[:,0:2]) to gen(real_data[:,0:2]).detach()
stops the generator from training. The code exactly as posted above trains the GAN successfully.

Could someone explain what difference .detach() makes here, especially in the context of the train_generator() function above?

Adding a .detach() breaks the gradient connection: the tensor it returns is cut off from the graph that produced it.
This means that any gradient flowing back towards fake_data won't be propagated to the generator, so no gradient will be populated for the generator's parameters.
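
A minimal sketch of that behaviour, assuming a single nn.Linear layer as a hypothetical stand-in for the generator (it is not the model from the question):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the generator: one linear layer.
gen = nn.Linear(2, 3)
x = torch.randn(4, 2)

# Without detach: the output is still attached to the generator's graph,
# so backward() populates the generator's .grad fields.
fake = gen(x)
fake.sum().backward()
grad_attached = gen.weight.grad is not None

# Reset the gradients, then repeat with detach.
for p in gen.parameters():
    p.grad = None

# With detach: same values, but no history and no path back to gen.
fake_detached = gen(x).detach()

print("grad populated without detach:", grad_attached)
print("detached.requires_grad:", fake_detached.requires_grad)
print("detached.grad_fn:", fake_detached.grad_fn)
print("gen grads after detach:", [p.grad for p in gen.parameters()])
```

Any loss built only from fake_detached has no route back to gen, which is why detaching on Line 1 leaves the generator with nothing to learn from.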

Alright, I got that part, but why do we need to detach when passing it to the discriminator step and not when passing it to the generator step?

Because you want the discriminator loss to only compute gradients for the discriminator and not the generator.
The .detach() here allows you to make sure this happens.

If you do it for the generator loss, then the generator loss won't contribute to the generator gradient, which is not what you want.
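
As a rough illustration of both steps, here is a sketch with toy models. The names gen, disc and bce are hypothetical stand-ins, and the layer sizes are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy models, not the ones from the question.
gen = nn.Linear(2, 3)
disc = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
bce = nn.BCELoss()

x = torch.randn(4, 2)
fake = gen(x)

# Discriminator step: detach, so the loss only reaches disc's parameters.
d_loss = bce(disc(fake.detach()), torch.zeros(4, 1))
d_loss.backward()
gen_grad_after_d = gen.weight.grad        # still None
disc_grad_after_d = disc[0].weight.grad   # populated

# Generator step: no detach, so the gradient flows through disc into gen.
g_loss = bce(disc(fake), torch.ones(4, 1))
g_loss.backward()
gen_grad_after_g = gen.weight.grad        # now populated

print("gen grad after D step:", gen_grad_after_d)
print("gen grad shape after G step:", tuple(gen_grad_after_g.shape))
```

Note that the generator step also accumulates gradients into disc, which is one reason the discriminator's optimizer should call zero_grad() before its next update.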


Thanks a lot @albanD. I understood it!

In this example, it seems that the generator's gradients are zeroed out before its weights are updated. In that case, does it still make a difference to the training process whether we detach or not, apart from the extra unnecessary computation?
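
One thing that can be checked directly is what happens when the same fake tensor is reused for the generator step, as in the first post. The sketch below uses hypothetical toy models. On a standard CPU build I would expect the second backward() to raise a RuntimeError, because the first backward() already freed the generator's part of the graph:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy models standing in for the generator and discriminator.
gen = nn.Linear(2, 3)
disc = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
bce = nn.BCELoss()

x = torch.randn(4, 2)
fake = gen(x)

# Discriminator step WITHOUT detach: backward also runs through gen's graph,
# fills gen's .grad, and frees the buffers saved for gen's backward pass.
bce(disc(fake), torch.zeros(4, 1)).backward()
gen_got_grad = gen.weight.grad is not None

# Zeroing the generator's gradients removes the unwanted values...
gen.zero_grad()

# ...but reusing the same `fake` for the generator step still fails,
# because gen's part of the graph has already been consumed.
try:
    bce(disc(fake), torch.ones(4, 1)).backward()
    second_backward_failed = False
except RuntimeError as err:
    second_backward_failed = True
    print("RuntimeError:", str(err)[:60])

print("gen received D-loss gradients:", gen_got_grad)
print("second backward failed:", second_backward_failed)
```

So under these assumptions the difference is not only wasted computation: without detach, the fake batch would have to be regenerated, or the first backward() called with retain_graph=True, before the generator step.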

I am printing the gradients of one layer of the generator, with and without .detach(). My understanding is that the gradients of the generator's weights should not change when discriminator_loss.backward() is called while using .detach() (since .detach() ensures the gradients are not backpropagated to the generator), but I am observing the opposite behavior. Regardless of whether .detach() is used, the values printed before and after discriminator_loss.backward() are different. Can anyone point out where I am wrong?

        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)

        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        print(netG.main[0].weight[0][0][1:3], "generator before netG(noise)")
        fake = netG(noise)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        #output = netD(fake).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        # In my thinking both the following print functions should provide same values when using .detach()
        print(netG.main[0].weight[0][0][1:3], "generator before errD_fake.backward() step")
        errD_fake.backward()
        print(netG.main[0].weight.grad[0][0][1:3], "generator grad after errD_fake.backward() step")