How to update the discriminator and generator multiple times in a GAN? RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Hi all, I am trying to convert a project written in TensorFlow to PyTorch, and it involves a GAN. In this code, I want to update the discriminator twice and the generator multiple times. But I keep getting an error that I cannot fix: "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation".

I know this happens because each optimizer.step() modifies the weights of the discriminator or generator in place, so a later backward pass through a graph that was built before that step (and still references the old weights) fails. However, this workflow works in TensorFlow, because there the gradients are simply taken out and the updates are applied multiple times.
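A minimal sketch of this failure mode, using a single hypothetical parameter `w` instead of the actual GAN: `(w ** 2).sum()` saves `w` for the backward pass, `optimizer.step()` then changes `w` in place, and a second backward through the retained graph raises the same error. Building a fresh graph after the step avoids it.

```python
import torch

# Hypothetical toy parameter standing in for the GAN weights.
w = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
opt = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()        # autograd saves w to compute d(loss)/dw = 2 * w
loss.backward(retain_graph=True)
opt.step()                   # updates w in place and bumps its version counter

try:
    loss.backward()          # reuses the retained graph, which expects the old w
    error_message = ""
except RuntimeError as err:
    error_message = str(err)
print("inplace operation" in error_message)

# Fix: build a fresh graph from the current weights before the next update.
opt.zero_grad()
loss = (w ** 2).sum()
loss.backward()
opt.step()
print(w.detach())
```

Note that `retain_graph=True` only keeps the graph alive; it does not make the saved tensors match parameters that were updated afterwards.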

I really don't know how to handle this problem, and I've been stuck on it for days. Could someone please help me with it? Many thanks. Below is my code.

            # ================================================================== #
            #                        Get the real, noise                         #
            # ================================================================== #

            x_real, targets = images.cuda(), targets.cuda()
            x_real = 2 * x_real - 1

            x_noise = maximizer(x_real)
            x_perturbed = x_real + epsilon[dataset] * x_noise

            d_out_real = minimizer(x_real)
            d_out_noise = minimizer(x_perturbed)

            d_loss_real = criterion(d_out_real, targets)
            d_loss_noise = criterion(d_out_noise, targets)

            d_loss = d_loss_real + d_loss_noise
            g_loss = -1.0 * d_loss_noise

            # update d_loss
            reset_grad()
            d_loss.backward(retain_graph=True)
            min_optimizer1.step()

            # ================================================================== #
            #                        Train the generator                         #
            # ================================================================== #

            # 1st step
            reset_grad()
            g_loss.backward(retain_graph=True)
            max_optimizer.step()

            # 2nd step
            x_noise2 = maximizer(x_real)
            x_perturbed2 = x_real + epsilon[dataset] * x_noise2
            d_out_noise2 = minimizer(x_perturbed2)
            d_loss_noise2 = criterion(d_out_noise2, targets)
            g_loss2 = -1.0 * d_loss_noise2
            reset_grad()
            g_loss2.backward(retain_graph=True)
            max_optimizer.step()

            # 3rd step
            x_noise3 = maximizer(x_real)
            x_perturbed3 = x_real + epsilon[dataset] * x_noise3
            d_out_noise3 = minimizer(x_perturbed3)
            d_loss_noise3 = criterion(d_out_noise3, targets)
            g_loss3 = -1.0 * d_loss_noise3
            reset_grad()
            g_loss3.backward(retain_graph=True)
            max_optimizer.step()

            # 4th step
            x_noise4 = maximizer(x_real)
            x_perturbed4 = x_real + epsilon[dataset] * x_noise4
            d_out_noise4 = minimizer(x_perturbed4)
            d_loss_noise4 = criterion(d_out_noise4, targets)
            g_loss4 = -1.0 * d_loss_noise4
            reset_grad()
            g_loss4.backward(retain_graph=True)
            max_optimizer.step()

            # 5th step
            x_noise5 = maximizer(x_real)
            x_perturbed5 = x_real + epsilon[dataset] * x_noise5
            d_out_noise5 = minimizer(x_perturbed5)
            d_loss_noise5 = criterion(d_out_noise5, targets)
            g_loss5 = -1.0 * d_loss_noise5
            reset_grad()
            g_loss5.backward(retain_graph=True)
            max_optimizer.step()

            # 6th step: virtual step
            x_noise6 = maximizer(x_real)
            x_perturbed6 = x_real + epsilon[dataset] * x_noise6
            d_out_noise6 = minimizer(x_perturbed6)
            g_loss6 = -1.0 * criterion(d_out_noise6, targets)
            d_loss_noise6 = criterion(d_out_noise6, targets)

            reset_grad()
            g_loss6.backward(retain_graph=True)
            max_virtual_optimizer.step()

            # ================================================================== #
            #                      Train the discriminator                       #
            # ================================================================== #

            # compute noiseB
            x_noiseB = maximizer(x_real)
            x_perturbedB = x_real + epsilon[dataset] * x_noiseB
            d_out_noiseB = minimizer(x_perturbedB)
            d_loss_noiseB = criterion(d_out_noiseB, targets)

            # combine: loss_real + loss_noise6 + loss_noiseB
            d_loss_full = d_loss_real + g_loss6 + gamma * (g_loss6 - d_loss_noiseB) / g_step_size

            reset_grad()
            d_loss_full.backward(retain_graph=True)
            min_optimizer1.step()

            # restore para for G
            minus_g_loss = -1.0 * g_loss6
            reset_grad()
            minus_g_loss.backward(retain_graph=True)
            max_virtual_optimizer.step()

Reusing the graph of g_loss after min_optimizer1.step() has already updated the discriminator in place would be wrong, wouldn't it?

Could you take a look at this post and check if you are using the same workflow?

Thanks Ptrblck, but I am still not able to fix the bug. I am trying to reproduce the code here: https://github.com/whxbergkamp/RobustDL_GAN/blob/master/cifar10/adversarial_networks/train_gan_cifar10.py , where the workflow is:
  1. G_optimizer1
  2. G_optimizer2
  3. G_optimizer3
  4. G_optimizer4
  5. G_optimizer5
  6. G_virtual_optimizer (forward)
  7. D_optimizer
  8. G_virtual_optimizer (backward)
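A minimal sketch of the general shape of such a schedule, assuming hypothetical toy versions of `maximizer`, `minimizer` and `criterion`: every G or D update runs its own forward pass, so no backward call ever reuses a graph whose parameters were already changed by `step()`. It deliberately leaves out the virtual step, its restore step, and the `gamma` / `g_step_size` term from the code in the question, since their exact semantics in the TensorFlow original are not spelled out in this thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the question's models: the maximizer (G) produces a
# perturbation and the minimizer (D) classifies the perturbed input.
maximizer = nn.Sequential(nn.Linear(8, 8), nn.Tanh())
minimizer = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()
max_optimizer = torch.optim.SGD(maximizer.parameters(), lr=0.01)
min_optimizer = torch.optim.SGD(minimizer.parameters(), lr=0.01)
epsilon = 0.1  # hypothetical perturbation scale

x_real = torch.randn(16, 8)
targets = torch.randint(0, 3, (16,))


def g_step():
    """One G update on a graph built from the current weights of G and D."""
    max_optimizer.zero_grad()
    x_perturbed = x_real + epsilon * maximizer(x_real)
    g_loss = -criterion(minimizer(x_perturbed), targets)
    # The graph is used once, so retain_graph is not needed. This call also
    # fills minimizer's .grad; d_step() zeroes it before its own backward.
    g_loss.backward()
    max_optimizer.step()
    return g_loss.item()


def d_step():
    """One D update; G's output is detached so only D receives gradients."""
    min_optimizer.zero_grad()
    x_perturbed = x_real + epsilon * maximizer(x_real).detach()
    d_loss = criterion(minimizer(x_real), targets) + criterion(minimizer(x_perturbed), targets)
    d_loss.backward()
    min_optimizer.step()
    return d_loss.item()


d_losses = [d_step()]
g_losses = [g_step() for _ in range(5)]  # five G updates, like G_optimizer1 to G_optimizer5
d_losses.append(d_step())
print(g_losses, d_losses)
```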

They are using TensorFlow and can manipulate the gradients directly. In PyTorch, however, I need to call optimizer.step() for G_virtual_optimizer (forward), D_optimizer, and G_virtual_optimizer (backward) based on G_virtual_loss and D_loss, where:

  - G_virtual_loss (forward) = g_loss6 = -1 * criterion(d_out_noise6, labels)
  - G_virtual_loss (backward) = -1 * g_loss6 = criterion(d_out_noise6, labels)
  - D_loss = d_loss_real + d_loss_noise6 + d_loss_noiseB = d_loss_real + criterion(d_out_noise6, labels) + d_loss_noiseB
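One PyTorch counterpart to manipulating gradients directly, sketched with hypothetical tiny `G` and `D` models: compute every gradient from the same forward pass with `torch.autograd.grad` before any parameter changes, then assign the results to `.grad` and call `step()`. This loosely mirrors TF's compute-then-apply pattern, but it only matches the intended semantics if all gradients should be evaluated at the same, pre-update parameter values; if a later update must see an earlier one, it needs a fresh forward pass instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny G and D; the point is the gradient bookkeeping, not the models.
G, D = nn.Linear(4, 4), nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
g_opt = torch.optim.SGD(G.parameters(), lr=0.1)
d_opt = torch.optim.SGD(D.parameters(), lr=0.1)
x, labels = torch.randn(8, 4), torch.randint(0, 2, (8,))

# One forward pass; d_out is shared by both losses, like d_out_noise6.
d_out = D(x + 0.1 * G(x))
d_loss = criterion(d_out, labels)
g_loss = -d_loss

# "Take out" the gradients before any parameter changes.
g_params, d_params = list(G.parameters()), list(D.parameters())
g_grads = torch.autograd.grad(g_loss, g_params, retain_graph=True)
d_grads = torch.autograd.grad(d_loss, d_params)
d_before = [p.detach().clone() for p in d_params]  # kept only for the check below

# Apply them afterwards; no backward pass uses this graph anymore.
for params, grads, opt in ((g_params, g_grads, g_opt), (d_params, d_grads, d_opt)):
    opt.zero_grad()
    for p, g in zip(params, grads):
        p.grad = g
    opt.step()

# Plain SGD moved each D parameter by -lr * grad.
print(all(torch.allclose(p.detach(), b - 0.1 * g) for p, b, g in zip(d_params, d_before, d_grads)))
```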
The problem is that g_loss6 and d_loss_noise6 share the same tensor, d_out_noise6. I tried g_loss6_copy = g_loss6.detach() so that I could update G_virtual_optimizer (forward), D_optimizer, and G_virtual_optimizer (backward) without the error. However, with detach(), the gradients explode very quickly. So I am now switching back to PyTorch 1.4 to bypass this issue.
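Note that detaching changes what is optimized, not only whether the error appears: a detached tensor is treated as a constant, so a detached g_loss6 added into the discriminator loss contributes no gradient at all. A scalar toy example with a hypothetical parameter `w` shows the effect; whether this accounts for the exploding gradients cannot be told from the thread alone.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # hypothetical scalar parameter

loss_a = w ** 2  # d(loss_a)/dw = 2 * w = 4
loss_b = 3 * w   # d(loss_b)/dw = 3

# Adding loss_b normally contributes its gradient ...
(grad_full,) = torch.autograd.grad(loss_a + loss_b, w, retain_graph=True)
# ... but a detached copy is a constant, so its contribution disappears.
(grad_detached,) = torch.autograd.grad(loss_a + loss_b.detach(), w)

print(grad_full.item(), grad_detached.item())  # 7.0 4.0
```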

My questions are:

  1. If I do the updates in PyTorch 1.4, does that exactly reproduce their TensorFlow workflow?
  2. If not, what is the difference, and how should I implement the workflow?
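A relevant detail for the PyTorch 1.4 question, stated as my understanding and worth checking against the 1.5.0 release notes: optimizers before 1.5 updated parameters through `.data`, which does not increment the version counter that autograd checks, so the error was simply not raised. The absence of the error does not mean the gradients matched the forward pass. A toy sketch with a hypothetical `w`, modifying it through `.data` after the forward pass:

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
loss = (w ** 2).sum()  # forward pass uses w = 1, so the matching gradient is 2

# Change w through .data: this bypasses the version counter, so autograd
# cannot detect that a saved tensor was modified after the forward pass.
w.data.add_(1.0)

loss.backward()  # no error is raised ...
print(w.grad.item())  # ... but the result is 2 * 2 = 4.0, computed from the new w
```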

I'm not deeply familiar with TensorFlow, but since stop_gradient ops were used, they would be equivalent to detached tensors. It could be the right approach to detach g_loss (I'm unsure, as I don't fully understand the TF code), and the exploding loss might have another root cause.
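The `detach()` / `tf.stop_gradient` correspondence can be checked on a scalar: detaching one factor makes autograd treat it as a constant, which is the role `tf.stop_gradient` plays in a TF graph. Here `x` is a hypothetical input.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)  # hypothetical scalar input

# Plain product: d(x * x)/dx = 2 * x = 6
(g_plain,) = torch.autograd.grad(x * x, x)

# One factor detached, playing the role of tf.stop_gradient:
# d(x * stop(x))/dx = stop(x) = 3
(g_stopped,) = torch.autograd.grad(x * x.detach(), x)

print(g_plain.item(), g_stopped.item())  # 6.0 3.0
```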
Were you able to validate that you are indeed trying to use stale gradients to update already updated parameters?

Yes, I am trying to use criterion(output_logits, labels) multiple times to compute g_loss and d_loss. I checked stop_gradient() in TensorFlow and tried to mimic it with detach() in PyTorch. But if I use detach(), I get an error saying the variables have no grad. If I don't use detach(), I get "one of the variables needed for gradient computation has been modified by an inplace operation".
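A "no grad" error of this kind typically comes from detaching the loss itself rather than an intermediate tensor: a detached loss has no `grad_fn`, so `backward()` has nothing to differentiate. A sketch with a hypothetical `x` (standing in for G's output) and `w` (standing in for a D parameter):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)  # hypothetical stand-in for G's output
w = torch.tensor(1.0, requires_grad=True)  # hypothetical stand-in for a D parameter
loss = (w * x - 3.0) ** 2

# 1) Detaching the loss itself leaves nothing to differentiate.
try:
    loss.detach().backward()
    message = ""
except RuntimeError as err:
    message = str(err)
print(message)

# 2) Detaching an intermediate tensor blocks gradients only along that path.
loss_d = (w * x.detach() - 3.0) ** 2
loss_d.backward()
print(w.grad.item(), x.grad)  # -4.0 None
```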

Since you mention stale gradients, how can I use criterion(output_logits, labels) multiple times in PyTorch? I want to use this loss to update G (forward) once, D once, and G (backward) once.

You could call backward multiple times if you pass retain_graph=True in the earlier calls.
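A minimal sketch of calling backward several times on one shared graph, assuming hypothetical tiny `G` and `D`: all backward calls run first and the optimizers step only afterwards, so nothing the graph saved is modified in between. The `inputs=` argument of `backward()` (to my knowledge added in PyTorch 1.8, so not available in 1.4) keeps each loss's gradients on one model; without it, both calls would accumulate into G and D, and since `g_loss == -d_loss` the sums would cancel to zero. As with any single forward pass, all gradients here are evaluated at the same parameter values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny G and D sharing one forward pass.
G, D = nn.Linear(4, 4), nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
g_opt = torch.optim.SGD(G.parameters(), lr=0.1)
d_opt = torch.optim.SGD(D.parameters(), lr=0.1)
x, labels = torch.randn(8, 4), torch.randint(0, 2, (8,))

out = D(x + 0.1 * G(x))  # shared output, like d_out_noise6
d_loss = criterion(out, labels)
g_loss = -d_loss

g_opt.zero_grad()
d_opt.zero_grad()
# inputs= restricts each backward call to one model's parameters.
g_loss.backward(retain_graph=True, inputs=list(G.parameters()))
d_loss.backward(inputs=list(D.parameters()))

# Step only after every backward call that needs this graph has run.
g_opt.step()
d_opt.step()
```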
However, could you describe how the computation graphs are created, which losses are computed with which parameters, how the gradients should be calculated, and when the models would be updated?
A drawing on a piece of paper might be easier, as I’m currently unsure what you are trying to achieve and would like to avoid making the code “run” somehow without checking if that’s indeed what you want.