DCGAN for super resolution training loop returns "Trying to backward through the graph a second time"

Hi,

I am trying to build a super-resolution model based on this idea: https://github.com/david-gpu/srez

It is almost like a DCGAN, but instead of feeding the generator a noise tensor of size (100, 1, 1), you provide the 16x16 downscaled true image (as far as I understand). To compute the generator loss, I need to:

  1. trick the discriminator (as in DCGAN)
  2. compute the L1 loss between the downscaled real image and the downscaled generated image (from my understanding)

I get the following error:

Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

My intuition tells me that it has to do with the use of .detach(). I had a similar issue when implementing DCGAN, but I don’t know how to fix it here.

Here is my training loop:

for epoch in range(epochs): 
    for i, data in enumerate(data_loader): 
        
        real_images, _ = data 
        real_images = real_images.to(device)
        # train discriminator
        discriminator.zero_grad()
        true_labels = torch.full(size=(batch_size,1), fill_value=1.0, device=device)
        real_preds = discriminator(real_images)
        real_loss = bce_criterion(real_preds, true_labels)  # (input, target) order, as in the fake_loss call
        real_loss.backward() # compute derivative
        
        
        # down scaling
        print("downscaling the image")
        downscaled_image = torch.nn.Upsample((ds, ds))(real_images)
        print(downscaled_image.shape)
            
            
        fake_labels = true_labels.fill_(0.0) 
        generated_images = generator(downscaled_image)
        generated_preds = discriminator(generated_images.detach())
        fake_loss = bce_criterion(generated_preds, fake_labels)
        fake_loss.backward() 
        
        discriminator_loss = fake_loss + real_loss
        d_optim.step()
        
        # training the generator
        generator.zero_grad()
        true_labels = fake_labels.fill_(1.0)# tricking the discriminator
        dg_out = discriminator(generated_images)
        generator_bce_loss = bce_criterion(dg_out, true_labels)  # (input, target) order
        generator_bce_loss.backward() 
        
        #down scaling the generated image 
        downscaled_generated = torch.nn.Upsample((ds, ds))(generated_images)
        l1_loss = l1smooth_criterion(downscaled_generated, downscaled_image)
        l1_loss.backward() # compute the derivatives
        
        generator_loss = l1_loss + generator_bce_loss
        # update params 
        g_optim.step()
        
        
        if i % 20 == 0: 
            print(f"Epoch : {epoch + 1} | Batch : {i+1} | D loss : {discriminator_loss} | G loss : {generator_loss}")
            writer.add_scalar('discriminator loss', d_loss / 20, epoch * len(data_loader) + i)
            writer.add_scalar('generator loss', g_loss / 20, epoch * len(data_loader) + i)
            plot_image(grid[:10])
            plt.show()

Thank you!

The error is most likely caused by these lines of code:

        # training the generator
        generator.zero_grad()
        true_labels = fake_labels.fill_(1.0)# tricking the discriminator
        dg_out = discriminator(generated_images)
        generator_bce_loss = bce_criterion(dg_out, true_labels)  # (input, target) order
        generator_bce_loss.backward() 
        
        #down scaling the generated image 
        downscaled_generated = torch.nn.Upsample((ds, ds))(generated_images)
        l1_loss = l1smooth_criterion(downscaled_generated, downscaled_image)
        l1_loss.backward() # compute the derivatives

generated_images is returned by the generator and feeds two separate loss computations, both of which share the generator's part of the computation graph:

  • via dg_out = discriminator(generated_images)
  • via downscaled_generated = torch.nn.Upsample((ds, ds))(generated_images)

A loss is computed from each of these outputs, and you call backward on both.
The first backward call frees the intermediate forward activations saved in the generator, so the second call fails as soon as it reaches the generator.
Either sum the two losses and call backward once, or pass retain_graph=True to the first backward call.
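As a rough illustration, here is a minimal sketch that reproduces the error and then applies the first fix. The tiny generator and discriminator are hypothetical stand-ins, not the models from the question, and I am assuming bce_criterion is nn.BCELoss and l1smooth_criterion is nn.SmoothL1Loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in networks: 16x16 input -> 64x64 "generated" image.
generator = nn.Sequential(
    nn.Upsample(scale_factor=4),
    nn.Conv2d(3, 3, kernel_size=3, padding=1),
    nn.Sigmoid(),
)
discriminator = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 64 * 64, 1),
    nn.Sigmoid(),
)

bce_criterion = nn.BCELoss()              # assumed loss type
l1smooth_criterion = nn.SmoothL1Loss()    # assumed loss type
downscale = nn.Upsample(size=(16, 16))

downscaled_image = torch.rand(2, 3, 16, 16)  # fake batch of low-res inputs
true_labels = torch.ones(2, 1)


def generator_losses():
    """Run one generator forward pass and return both loss terms."""
    generated_images = generator(downscaled_image)
    adv_loss = bce_criterion(discriminator(generated_images), true_labels)
    l1_loss = l1smooth_criterion(downscale(generated_images), downscaled_image)
    return adv_loss, l1_loss


# Reproduce the problem: two backward calls through the same generator graph.
adv_loss, l1_loss = generator_losses()
adv_loss.backward()  # frees the tensors saved inside the generator
try:
    l1_loss.backward()  # needs those freed tensors
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

# Fix: sum the losses and call backward once.
generator.zero_grad()
adv_loss, l1_loss = generator_losses()
generator_loss = adv_loss + l1_loss
generator_loss.backward()
```

Calling adv_loss.backward(retain_graph=True) followed by l1_loss.backward() should accumulate the same generator gradients, at the cost of keeping the graph in memory until the second call. Note that either way the generator step also accumulates gradients in the discriminator; in the loop from the question these are cleared by discriminator.zero_grad() at the start of the next iteration.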


Hi!
That indeed fixed the problem!
Thank you very much for your help <3