DCGAN - Trying to backward through the graph a second time

I know there are similar posts. I nearly checked them all. Tried .detach(). and retain_graph=True options. However I couldn’t solve it for my training loop of DCGAN. This is a fairly complex training loop for me so I wouold be glad for any help. I also checked the PyTorch tutorial and created this version.

criterion = nn.BCELoss()

discriminator_optimizer = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
generator_optimizer = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

img_list = []
G_losses = []
D_losses = []
iters = 0

start = timeit.default_timer()
for epoch in tqdm(range(EPOCHS), position=0, leave=True):
    generator.train()
    discriminator.train()
    for idx, data in enumerate(tqdm(dataloader, position=0, leave=True)):
        img_data = data[0].to(device) # size [1024, 3, 64, 64]
        dummy_labels = data[1] # size [1024]
        
        real_labels = torch.full((dummy_labels.size()), 1., dtype=torch.float).to(device) # size [1024]
        fake_labels = torch.full((dummy_labels.size()), 0., dtype=torch.float).to(device)
        noise = torch.randn(dummy_labels.size()[0], INPUT_VECTOR_DIM, 1, 1).to(device) # size [1024, 100, 1, 1]
        
        discriminator.zero_grad() 
        
        discriminator_real_out = discriminator(img_data).view(-1).to(device) # size [1024] .view(-1) to rid unnecessary dimensions
        discriminator_real_loss = criterion(discriminator_real_out, real_labels)
        discriminator_real_loss.backward()
        
        generator_fake_out = generator(noise)  # size [1024, 3, 64, 64]
        discriminator_fake_out = discriminator(generator_fake_out.detach()).view(-1).to(device)
        discriminator_fake_loss = criterion(discriminator_fake_out, fake_labels)
        discriminator_fake_loss.backward()
        discriminator_loss = discriminator_real_loss.item() + discriminator_fake_loss.item()
        discriminator_optimizer.step()

        generator_loss = criterion(discriminator_fake_out, real_labels)
        generator_loss.backward()
        generator_optimizer.step()
        generator.zero_grad()

Problem is from this line generator_loss.backward(). Also when I remove the line discriminator_optimizer.step() the problem disappears.

Full error: RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Hi Uygar!

As a general principle, don’t just try things. Analyze the actual cause
of your specific issue and use the “option” that actually addresses that
cause.

Please look at the comments that I’ve added in line to your quoted code:

The key issue is that you are using discriminator_fake_out to
optimize both discriminator and generator (leading in this particular
attempt to the “backward a second time” error). If you try using
retain_graph = True without doing things just so, you are
likely to get an inplace-modification error.

Probably the simplest way to address this issue with training GANs is
to rebuild the discriminator computation graph by calling – at added
computational cost – discriminator_fake_out = discriminator (...)
twice, once for the discriminator_fake_loss.backward()
backpropagation and then again for the generator_loss.backward()
backpropagation.

Yes, as explained in the in-line comments I added to your code.

I don’t believe this. For the version of the code you posted, you will get the
“backward a second time” error after calling generator_loss.backward(),
regardless of whether discriminator_optimizer.step() was called or not.

Perhaps in a different version of your code you had an inplace-modification
error that removing discriminator_optimizer.step() appeared to fix.

If you have (or do) come across inplace-modification errors, it will be
because discriminator_optimizer.step() is modifying discriminator’s
parameters inplace.

A discussion about fixing inplace-modification errors that includes a
toy-GAN example can be found in this post:

Best.

K. Frank

1 Like

Hi Frank. Thank you for the comprehensive answer.

Great catch! the problem is because of the discriminator_fake_out since I use it at both discriminator and generator.

Calculating the computation graph for discriminator_fake_out for a second time solves the problem. I checked the PyTorch tutorial again and that’s also how they implemented it.

Now the first question comes into mind is can we solve it without calculating the discriminator_fake_out for a second time?

I did
discriminator_fake_out = torch.clone(discriminator_fake_out).detach().requires_grad_()
and the dummy version run without any problems. I’ll look at the actual results to see if it works.

Do you think this is also a correct approach?

Why would you want to, though? The results of the discriminator are different(i.e. better) after that first optimizerD.step(), and you want those better results to optimize the generator.

You’re right. However we’re feeding the same input generator_fake_out. So I couldn’t be sure if it worths it or becomes just an extra operation. What do you think?

Hi Uygar!

Yes. One approach to not recalculating discriminator_fake_out is
illustrated in the toy-GAN example in the post I linked to above.

The basic idea is to modify discriminator so that its forward pass uses
clone()s (where necessary) of its parameters because they get modified
inplace when you call discriminator_optimizer.step().

No. This will not backpropagate through discriminator (and will therefore
not backpropagate through generator). After calling .detach(), the new
tensor referred to by discriminator_fake_out is not part of any computation
graph, so any backpropagation stops at that point.

Calling .requires_grad_() does not reconnect discriminator_fake_out
to the computation graph. (As a general rule resetting requires_grad to
True to “fix” an issue is an error. Doing so can suppress reporting of some
error messages, but this doesn’t mean your code it correct.)

Best.

K. Frank