Training a GAN leads to "RuntimeError: Trying to backward through the graph a second time ..."

Hi,

I am trying to train a GAN that generates both an image and a label map. The generator has a common backbone followed by two heads, one producing the image and one producing the label. I also have two separate discriminators: one for images and one for labels.
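For context, the overall structure is roughly like the sketch below (layer sizes, channel counts and the class/function names are placeholders, not my actual model):

import torch
import torch.nn as nn

class TwoHeadGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        # shared backbone
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
        )
        # one head per output
        self.image_head = nn.Conv2d(16, 1, 3, padding=1)
        self.label_head = nn.Conv2d(16, 1, 3, padding=1)

    def forward(self, x):
        features = self.backbone(x)
        return self.image_head(features), self.label_head(features)

def make_discriminator():
    # same structure for both discriminators, but separate weights
    return nn.Sequential(
        nn.Conv2d(1, 16, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        nn.Linear(16, 1), nn.Sigmoid(),
    )

netG = TwoHeadGenerator()
netD_image = make_discriminator()
netD_label = make_discriminator()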

Here is a snippet illustrating one training iteration:

############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
## Train with all-real batch

# Format batch
real_cpu = data[0].to(device)
patch = patch[0].to(device)
label_mask = label_mask[0].to(device)

requires_grad(netG, False)
requires_grad(netD_image, True)
requires_grad(netD_label, False)

b_size = real_cpu.size(0)
label = torch.rand((b_size,), device=device)*0.1 + 0.9
# Forward pass real batch through D
output_image = netD_image(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real_img = criterion(output_image, label)

## Train with all-fake batch

# Generate fake image batch with G
fake_img, fake_label = netG(patch)
label = torch.rand((b_size,), device=device)*0.1 

# Classify all fake batch with D
output_image = netD_image(fake_img.detach()).view(-1)

# Calculate D's loss on the all-fake batch
errD_fake_img = criterion(output_image, label)

netD_image.zero_grad()

# Compute error of D as sum over the fake and the real batches
errD_img = errD_real_img + errD_fake_img
# Calculate the gradients for this batch, accumulated (summed) with previous gradients
errD_img.backward()
# Update D
optimizerD_img.step()

# Same as before, but with the label discriminator
requires_grad(netG, False)
requires_grad(netD_image, False)
requires_grad(netD_label, True)

fake_img, fake_label = netG(patch)

output_fake_lmask = netD_label(fake_label.detach()).view(-1)
output_real_lmask = netD_label(label_mask.detach()).view(-1)

label = torch.rand((b_size,), device=device)*0.1 + 0.9
errD_real_lmask = criterion(output_real_lmask, label)

label = torch.rand((b_size,), device=device)*0.1
errD_fake_lmask = criterion(output_fake_lmask, label)

netD_label.zero_grad()
errD_lmask = errD_real_lmask + errD_fake_lmask
errD_lmask.backward()
optimizerD_label.step()

errD = errD_img + errD_lmask

############################
# (2) Update G network: maximize log(D(G(z)))
###########################

requires_grad(netG, True)
requires_grad(netD_image, False)
requires_grad(netD_label, False)
netG.zero_grad()

fake_img, fake_label = netG(patch)
label.fill_(real_label)  # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of all-fake batch through D
output_img = netD_image(fake_img).view(-1)
output_lmask = netD_label(fake_label).view(-1)
# Calculate G's loss based on this output
adv_loss_img = criterion(output_image, label)
adv_loss_lmask = criterion(output_lmask, label)
adv_loss = adv_loss_img+adv_loss_lmask
#print("output G: ", fake.shape)
rloss_img = mse(fake_img, real_cpu)
rloss_lmask = mse(fake_label, label_mask)
reconstruction_loss = rloss_img + rloss_lmask
errG = w_a * adv_loss + w_r * reconstruction_loss

# Calculate gradients for G
with torch.autograd.set_detect_anomaly(True):
    errG.backward()
# Update G
optimizerG.step()
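(For completeness, requires_grad is just a small helper that toggles gradient tracking on all parameters of a module, roughly like this:)

def requires_grad(model, flag=True):
    # assumed implementation of the helper used above:
    # enable/disable gradient tracking for every parameter of the module
    for p in model.parameters():
        p.requires_grad_(flag)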

I don’t understand why, but when I launch the training script the following error is raised:

EPOCH 0/9
0it [00:00, ?it/s]/home/iaslab/GAN_venv/lib/python3.8/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in SigmoidBackward0. No forward pass information available. Enable detect anomaly during forward pass for more information. (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:92.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
0it [00:01, ?it/s]
Traceback (most recent call last):
  File "condDCGAN_label_2D.py", line 377, in <module>
    errG.backward()
  File "/home/iaslab/GAN_venv/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/iaslab/GAN_venv/lib/python3.8/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Thank you for your help!

In the generator update you are accidentally reusing a discriminator output:

output_img = netD_image(fake_img).view(-1)
output_lmask = netD_label(fake_label).view(-1)
# Calculate G's loss based on this output
adv_loss_img = criterion(output_image, label)

Note the adv_loss_img calculation: it uses output_image, which was computed earlier via output_image = netD_image(fake_img.detach()).view(-1) while you were training the image discriminator, instead of the fresh output_img computed just above. The graph behind output_image was already backpropagated through (and its saved tensors freed) by errD_img.backward(), so errG.backward() tries to go through that graph a second time and raises the error.
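You can reproduce the same message in isolation with a toy example (sigmoid is used here only because it saves its output for the backward pass, which matches the SigmoidBackward0 in your warning):

import torch

x = torch.randn(3, requires_grad=True)
y = torch.sigmoid(x).sum()  # sigmoid saves its output for backward
y.backward()   # first backward frees the saved tensors of this graph
y.backward()   # RuntimeError: Trying to backward through the graph a second time ...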
Fix it and it should work.
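Concretely, the generator step should use the freshly computed outputs, something like:

output_img = netD_image(fake_img).view(-1)
output_lmask = netD_label(fake_label).view(-1)

adv_loss_img = criterion(output_img, label)      # was: criterion(output_image, label)
adv_loss_lmask = criterion(output_lmask, label)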
