Trying to backward through the graph a second time (ESRGAN)

Hi,
I have a super-resolution model (ESRGAN), and I get the following error:

> Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
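For reference, here is a minimal sketch that reproduces this error outside of a GAN. The toy nn.Sequential model is a hypothetical stand-in for the generator. Two losses are built from the same forward pass, and the second backward() fails because the first call already freed the saved intermediate tensors:

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for the generator
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
x = torch.randn(2, 4)

out = model(x)               # one forward pass -> one graph
loss_a = out.pow(2).mean()   # both losses share that graph
loss_b = out.abs().mean()

loss_a.backward()            # frees the saved tensors of the shared graph
try:
    loss_b.backward()        # needs the same saved tensors -> RuntimeError
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```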

Here is my training loop:

    for i, sample in enumerate(train_loader): 
        
        hr = sample["hr"].to(config.device) 
        lr = sample["lr"].to(config.device) 

        # generated super resolution 
        sr = generator(lr)

        ###### train discriminator ######


        for param in discriminator.parameters():
            param.requires_grad = True

        discriminator.zero_grad()

        true_label = torch.full(size=(sr.shape[0], 1), fill_value=1.0, device=config.device)
        fake_label = torch.full(size=(sr.shape[0], 1), fill_value=0.0, device=config.device)

        predicted_true = discriminator(hr)
        predicted_fake = discriminator(sr.detach())

        d_loss_true = adversarial_criterion(torch.sigmoid(predicted_true - predicted_fake.mean(dim=0)), true_label)
        d_loss_true.backward()
        d_loss_fake = adversarial_criterion(torch.sigmoid(predicted_fake - predicted_true.mean(dim=0)), fake_label)
        d_loss_fake.backward()
        # optimization 
        
        d_loss = d_loss_fake + d_loss_true
        d_optim.step()


        ###### train generator ######
        for param in discriminator.parameters():
            param.requires_grad = False
        
        generator.zero_grad()

        d_out_generated = discriminator(sr)
        # mse/vgg loss
        vgg_loss = vgg_criterion(sr, hr)
        vgg_loss.backward()
        # tricking the discriminator
        adversarial_loss = config.adversarial_coefficient * adversarial_criterion(d_out_generated, true_label) 
        adversarial_loss.backward()
        # l1 criterion
        l1_loss = config.l1_coefficient * l1_criterion(sr, hr)
        l1_loss.backward()
        # relativistic loss
        relativistic_loss = config.relativistic_coefficient * adversarial_criterion(torch.sigmoid(d_out_generated - predicted_true.mean(dim=0)), true_label)
        relativistic_loss.backward()

        # complete loss
        g_loss = vgg_loss + adversarial_loss + l1_loss + relativistic_loss
        # optimization 
        g_optim.step()

        # writing with tensorboard
        writer.add_scalar(f"{config.train_mode}/D_LOSS", d_loss, epoch*len(train_loader) + i + 1)
        writer.add_scalar(f"{config.train_mode}/G_LOSS", g_loss, epoch*len(train_loader) + i + 1)
        writer.add_scalar(f"{config.train_mode}/l1_loss", l1_loss, epoch*len(train_loader) + i + 1)
        writer.add_scalar(f"{config.train_mode}/vgg_loss", vgg_loss, epoch*len(train_loader) + i + 1)
        writer.add_scalar(f"{config.train_mode}/adversarial_loss", adversarial_loss, epoch*len(train_loader) + i + 1)
        writer.add_scalar(f"{config.train_mode}/relativistic_loss", relativistic_loss, epoch*len(train_loader) + i + 1)

        if i % 200 == 0 and i != 0: 
            print(f"EPOCH={epoch} [{i}/{len(train_loader)}]D_LOSS in {config.train_mode} mode : {d_loss} ")  
            print(f"EPOCH={epoch} [{i}/{len(train_loader)}]G_LOSS in {config.train_mode} : {g_loss} ")  

Do you see where the error is?

Thank you!

Based on your code, it seems you are trying to calculate the gradients in the generator multiple times via:

    vgg_loss = vgg_criterion(sr, hr)
    vgg_loss.backward()
    ...
    adversarial_loss = config.adversarial_coefficient * adversarial_criterion(d_out_generated, true_label) 
    adversarial_loss.backward()
    ...
    l1_loss = config.l1_coefficient * l1_criterion(sr, hr)
    l1_loss.backward()
    ...
    relativistic_loss = config.relativistic_coefficient * adversarial_criterion(torch.sigmoid(d_out_generated - predicted_true.mean(dim=0)), true_label)
    relativistic_loss.backward()

If I'm not mistaken, all of these losses are computed from the generator output sr, so each backward() call tries to compute the gradients of the generator parameters w.r.t. that loss.
If so, note that each backward() call frees the intermediate forward activations in the generator (which are needed for the gradient computation) unless you pass retain_graph=True. Either use this argument or sum the losses into a single final loss and call backward() once.
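A minimal sketch of both options, again assuming a hypothetical toy model rather than the real generator. Because gradients accumulate in .grad, calling backward() on each loss with retain_graph=True yields the same parameter gradients as summing the losses and calling backward() once:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model standing in for the generator
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
x = torch.randn(2, 4)


def param_grads(m):
    # Copy of the accumulated .grad of every parameter
    return [p.grad.clone() for p in m.parameters()]


# Option 1: keep the graph alive between backward() calls
model.zero_grad()
out = model(x)
loss_a = out.pow(2).mean()
loss_b = out.abs().mean()
loss_a.backward(retain_graph=True)  # keep saved tensors for the next call
loss_b.backward()                   # last call may free the graph
grads_retain = param_grads(model)

# Option 2: sum the losses and call backward() once
model.zero_grad()
out = model(x)
total = out.pow(2).mean() + out.abs().mean()
total.backward()
grads_sum = param_grads(model)

# Gradients accumulate, so both options give the same result
same = all(torch.allclose(a, b) for a, b in zip(grads_retain, grads_sum))
print(same)
```

Summing the losses is usually the simpler choice: it runs a single backward pass, and the graph is freed right afterwards instead of being kept alive across several calls.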


Hi,

Thank you for your answer.
As you suggested, I added all the losses together, but I still get the error.
However, the error goes away when I remove relativistic_loss from the sum.
Do you have any idea why?

Thank you once again

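Maybe call .detach() on predicted_true, as I assume you don't want to use it to calculate the gradients in the discriminator again. predicted_true is still attached to the discriminator's graph, whose intermediate values were already freed by the discriminator's backward() call, so backpropagating relativistic_loss tries to go through that graph a second time.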
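A minimal sketch of why .detach() helps here. It uses tiny hypothetical stand-ins for the generator and discriminator, and assumes nn.BCELoss as adversarial_criterion (applied after torch.sigmoid, as in the training loop). The discriminator's backward() frees the graph behind predicted_true, so reusing it in the generator loss only works if it is detached, i.e. treated as a constant:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny stand-ins for the real ESRGAN networks
generator = nn.Linear(4, 4)
discriminator = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(0.2), nn.Linear(8, 1))
criterion = nn.BCELoss()  # assumed adversarial_criterion (expects probabilities)

lr_img, hr_img = torch.randn(3, 4), torch.randn(3, 4)
ones = torch.ones(3, 1)


def train_step(detach_true):
    """One discriminator step followed by the relativistic generator loss.

    Returns None on success, or the RuntimeError message otherwise.
    """
    sr = generator(lr_img)

    # Discriminator step: this backward() frees the graph behind predicted_true
    predicted_true = discriminator(hr_img)
    predicted_fake = discriminator(sr.detach())
    d_loss = criterion(torch.sigmoid(predicted_true - predicted_fake.mean(dim=0)), ones)
    d_loss.backward()

    # Generator step: predicted_true should only act as a constant reference
    reference = predicted_true.detach() if detach_true else predicted_true
    d_out_generated = discriminator(sr)
    relativistic_loss = criterion(
        torch.sigmoid(d_out_generated - reference.mean(dim=0)), ones
    )
    try:
        relativistic_loss.backward()
        return None
    except RuntimeError as err:
        return str(err)


print(train_step(detach_true=True))   # None
print(train_step(detach_true=False))  # "Trying to backward through the graph a second time ..."
```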


Incredible, I had just done that and came back to the forum to post the solution.
Thank you very much! This whole detach() thing is confusing to me; I need to pay more attention.
Thank you!
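Since detach() came up as the confusing part, here is a small standalone sketch of what it does (the tensors are made up for illustration). The detached tensor has the same values and shares storage with the original, but it is cut out of the autograd graph, so no gradient flows back through it:

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
y = w * 3             # y is part of the graph: dy/dw = 3
y_const = y.detach()  # same values (6.0), but cut off from the graph

print(y_const.requires_grad)               # False
print(y_const.data_ptr() == y.data_ptr())  # True: shares storage with y

# Gradient flows through y but not through y_const:
# loss = y * 6, so d(loss)/dw = 6 * 3 = 18
loss = (y * y_const).sum()
loss.backward()
print(w.grad)  # tensor([18.])
```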
