Trouble understanding how the optimizer works


I have written the GAN network based on my understanding of the DCGAN tutorial, and I am getting a runtime error that I do not fully understand. The code I use to train the networks is given below:

Generator_Loss = []
Discriminator_Loss = []
generated_images = []

for epoch in range(args['epochs']):
  for i, data in enumerate(celebA_dataLoader):
    # Training discriminator on real data
    real_data_on_gpu = data[0].to(device)
    y_true = torch.full((args["batch_size"], 1), 1.0, device=device)
    y_pred_real = discriminator_net.forward(real_data_on_gpu)
    real_loss = loss_criterion(y_pred_real, y_true)
    # Training discriminator on fake data
    latent_variables = torch.as_tensor(rand.randn(args["batch_size"], args["dim_Z"], 1, 1).astype(np.float32), device=device)
    fake_data_on_gpu = generator_net.forward(latent_variables)
    y_pred_fake = discriminator_net.forward(fake_data_on_gpu)
    fake_loss = loss_criterion(y_pred_fake, y_true.fill_(0))
    discriminator_loss = real_loss + fake_loss
    # Training generator
    y_pred_fake = discriminator_net.forward(fake_data_on_gpu)
    generator_loss = loss_criterion(y_pred_fake, y_true.fill_(1))
    if i % 50 == 0:
      print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
            % (epoch, args["epochs"], i, len(celebA_dataLoader),
               discriminator_loss.item(), generator_loss.item()))
    if epoch == args["epochs"] - 1:
      with torch.no_grad():
        fake_images = generator_net(latent_variables).detach().cpu()

I am getting the following runtime error at the line generator_loss.backward():

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I have searched for this error, but I am still not clear on how the optimizer works. My understanding is that a forward pass creates a computational graph of the network, and that this graph is deleted after a backward pass. In my case, I did only one forward pass through the generator, yet I get this error when I try to do the backward pass.

Could someone please help me understand why this is happening?

Warm Regards,

Could you try to detach the generator output when you train the discriminator on the fake data?

y_pred_fake = discriminator_net(fake_data_on_gpu.detach())

This will make sure not to create gradients in the generator in this step.
When you train your generator you should just pass the fake input without detaching as shown in the DCGAN example.
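A minimal sketch of the two update steps described above, assuming tiny hypothetical nn.Linear stand-ins for generator_net and discriminator_net, nn.BCELoss as loss_criterion, and Adam optimizers named opt_d and opt_g (none of these details come from the original post). The generator output is detached only for the discriminator update, so the generator's part of the graph is still intact when generator_loss.backward() runs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny stand-ins for generator_net / discriminator_net
dim_z, dim_x, batch_size = 8, 16, 4
generator_net = nn.Sequential(nn.Linear(dim_z, dim_x), nn.Tanh())
discriminator_net = nn.Sequential(nn.Linear(dim_x, 1), nn.Sigmoid())
loss_criterion = nn.BCELoss()
opt_d = torch.optim.Adam(discriminator_net.parameters(), lr=2e-4)
opt_g = torch.optim.Adam(generator_net.parameters(), lr=2e-4)


def train_step(real_data):
    n = real_data.size(0)  # use the actual batch size (the last batch may be smaller)
    real_labels = torch.ones(n, 1)
    fake_labels = torch.zeros(n, 1)

    # Discriminator update: detach the fake batch so backward stops before the generator
    opt_d.zero_grad()
    real_loss = loss_criterion(discriminator_net(real_data), real_labels)
    fake_data = generator_net(torch.randn(n, dim_z))
    fake_loss = loss_criterion(discriminator_net(fake_data.detach()), fake_labels)
    discriminator_loss = real_loss + fake_loss
    discriminator_loss.backward()
    opt_d.step()

    # Generator update: reuse fake_data WITHOUT detaching; its generator graph was never freed
    opt_g.zero_grad()
    generator_loss = loss_criterion(discriminator_net(fake_data), real_labels)
    generator_loss.backward()
    opt_g.step()
    return discriminator_loss.item(), generator_loss.item()


d_loss, g_loss = train_step(torch.randn(batch_size, dim_x))
print(d_loss, g_loss)
```

Reusing the same fake batch for the generator step mirrors the DCGAN example; sampling a fresh batch would also work, at the cost of an extra generator forward pass.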

Also, as a small side note: you shouldn’t call forward directly; instead, call the model with the input, as shown in my code snippet.
This makes sure that any registered hooks are run.
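A small illustration of the hook point, assuming a toy nn.Linear model and a forward hook that only records that it ran. Calling the module goes through nn.Module.__call__, which runs registered hooks, while calling .forward() directly skips them:

```python
import torch
import torch.nn as nn

calls = []
model = nn.Linear(3, 2)
# Forward hook that records every time it fires
model.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(1, 3)
model(x)          # goes through nn.Module.__call__, so the hook runs
model.forward(x)  # bypasses __call__, so the hook is silently skipped
print(calls)      # ['hook']
```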

Hi @ptrblck,

Thank you for the suggestions. It works perfectly now. Could you please also explain the rationale behind your suggestions?

If you don’t detach fake_data_on_gpu, the fake_loss.backward() call will compute gradients in both the discriminator and the generator, since the computation graph was created as:
random input -> generator -> discriminator -> loss function

In the discriminator update step, you don’t need the gradients in the generator for the “fake” label, so you can skip computing them by detaching the generator output:
random input -> generator -> detach output -> discriminator -> loss function

Now the backward call will only calculate gradients up to the point where the graph was detached.
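To see the effect of detaching in isolation, here is a sketch with two hypothetical nn.Linear stand-ins (not the poster's models). Backpropagating from the detached output leaves the generator's .grad untouched, while the non-detached path reaches it:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
generator_net = nn.Linear(4, 4)      # hypothetical stand-in
discriminator_net = nn.Linear(4, 1)  # hypothetical stand-in

fake = generator_net(torch.randn(2, 4))

# Detached: backward stops at the detach point, so the generator's grads stay None
discriminator_net(fake.detach()).sum().backward()
detached_g_grad = generator_net.weight.grad
print(detached_g_grad)  # None

# Not detached: backward flows through the discriminator into the generator
discriminator_net(fake).sum().backward()
attached_g_grad = generator_net.weight.grad
print(attached_g_grad is not None)  # True
```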

In the generator update step, you are using a “real” label and need the gradients in the generator, so you have to use the first approach without detaching the graph.

Without detaching, you not only waste some computation (and might even end up with wrong gradients in the generator if you don’t clear them before its update), but you might also run into an error, as in your case.
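A sketch of the “wrong gradients” case, again with hypothetical nn.Linear stand-ins and arbitrary losses chosen only for illustration. Without detaching and without zeroing the gradients in between, the generator's .grad ends up holding the sum of the stale discriminator-step gradient and the intended generator-step gradient:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
generator_net = nn.Linear(4, 4)      # hypothetical stand-in
discriminator_net = nn.Linear(4, 1)  # hypothetical stand-in
z = torch.randn(2, 4)

# Discriminator-style step WITHOUT detach: gradients also land in the generator
discriminator_net(generator_net(z)).pow(2).sum().backward()
stale = generator_net.weight.grad.clone()

# Generator-style step WITHOUT zeroing first: the new gradient is added to the stale one
discriminator_net(generator_net(z)).sum().backward()
accumulated = generator_net.weight.grad.clone()

# The gradient the generator step alone should have produced
generator_net.zero_grad()
discriminator_net(generator_net(z)).sum().backward()
fresh = generator_net.weight.grad.clone()

print(torch.allclose(accumulated, stale + fresh))  # True
```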

After the fake_loss.backward() call, the intermediate activations stored in the graph, including those of the generator, are freed to save memory. Since generator_loss.backward() needs to backpropagate through these same generator intermediates, the error was raised.
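A minimal reproduction of this error, assuming a hypothetical nn.Sequential generator with a Tanh activation (which stores its output for the backward pass) and a single nn.Linear discriminator. The second backward call reuses the generator part of the graph, whose saved tensors were already freed by the first call:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
generator_net = nn.Sequential(nn.Linear(4, 4), nn.Tanh())  # hypothetical stand-in
discriminator_net = nn.Linear(4, 1)                         # hypothetical stand-in

fake = generator_net(torch.randn(2, 4))

# First backward goes through the generator and frees its saved activations
discriminator_net(fake).sum().backward()

# Second backward needs the same generator activations, so autograd raises
try:
    discriminator_net(fake).sum().backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```

Passing retain_graph=True to the first backward call would make the error go away, but as explained in this thread, detaching the generator output for the discriminator update is the cleaner fix, because the discriminator step never needs that part of the graph.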

Let me know, if something is still unclear.


@ptrblck That was a wonderful explanation. Thank you. It makes perfect sense.
