Trying to backward through the graph a second time SRGAN

Hi,

I am trying to implement SRGAN, I have looked at many implementation on Github and none of them used the retain_graph = True option.

I have the following error :

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time.

Here is my training Loop :

# training loop 
torch.autograd.set_detect_anomaly(True)
dnet.train()
gnet.train()

discriminator_accuracy = 0 
generator_accuracy = 0
d_loss = 0 
g_loss = 0

for epoch in range(epochs): 
    
    for i, sample in enumerate(data_loader): 
        hr, lr = sample["hr"], sample["lr"] 
        hr = hr.to(device)
        lr = lr.to(device)
        # TODO : add noise
        
        generated_sr = gnet(lr)
        true_labels = torch.full(size=(batch_size,1), fill_value=1, device=device, dtype=torch.float32)
        
        ################# TRAIN DISCRIMINATOR ###############
        d_optim.zero_grad()
        for p in dnet.parameters():
            p.requires_grad = True
        
        
        
        true_preds = dnet(hr)
        true_loss = bce_criterion(true_preds, true_labels)  
        true_loss.backward()
        
        
        fake_labels = true_labels.fill_(0) 
        fake_preds = dnet(generated_sr.detach())
        
        discriminator_accuracy += compute_accuracy(fake_preds, fake_labels)
        
        
        fake_loss = bce_criterion(fake_preds, fake_labels)
        fake_loss.backward()
        # update parameters 
        d_loss += fake_loss + true_loss 
        # zero grad 
        d_optim.step()

        ################# TRAIN GENERATOR ####################
        true_labels = fake_labels.fill_(1)
        for p in dnet.parameters():
            p.requires_grad = False
            
        # zero grad 
        # tricking the discriminator
        dg_out = dnet(generated_sr)
        # compute total loss regarding the generator
        g_bce_loss = 1e-3 * bce_criterion(dg_out,true_labels)
        # content loss
        content_loss = 1.0 * content_criterion(generated_sr, hr.detach())
        #pixel_loss = 1.0 * mse_criterion(generated_sr, hr.detach())
        
        generator_accuracy += compute_accuracy(dg_out, true_labels)
        
        # update parameters 
        g_loss += content_loss + g_bce_loss + pixel_loss
        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()
        
        
        
        
        if i % 20 == 0 and i != 0: 
            writer.add_scalar("Loss/discriminator", d_loss / 20, i)
            writer.add_scalar("Loss/generator", g_loss / 20, i)
            writer.add_scalar("Accuracy/discriminator", discriminator_accuracy / 20, i)
            writer.add_scalar("Accuracy/generator", generator_accuracy / 20, i)
            print(f"EPOCH={epoch} | BATCH={i} | GLOSS={g_loss / 20} | DLOSS={d_loss / 20} | ACC_DISC={discriminator_accuracy / 20} | ACC_GENER={generator_accuracy/20}")
            
            d_loss = 0 
            g_loss = 0 
            discriminator_accuracy = 0 
            generator_accuracy = 0
            
        if i % 100 == 0: 
            with torch.no_grad(): 
                downsampled = lr[0]
                generated = generated_sr[0]
                gt = hr[0]
                
                show(downsampled)
                show(generated) 
                show(gt)

thank you very much !

You are accumulating the losses in:

d_loss += fake_loss + true_loss 
...
g_loss += content_loss + g_bce_loss + pixel_loss

which will keep the computation graphs alive and Autograd will thus try to backpropagate through the computation graphs from all iterations.
I guess this behavior is not intended so you could assign d_/g_loss to the newly computed iteration loss instead.

1 Like

Fixed the entire thing ! thx :slight_smile: