RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad()

I am trying to train SRGAN from scratch. I have read solutions for this type of problem, but it would be great if someone could help me debug my code.

import torch
from torch import nn, optim

gen_model = Generator().to(device, non_blocking=True)
disc_model = Discriminator().to(device, non_blocking=True)
opt_gen = optim.Adam(gen_model.parameters(), lr=0.01)
opt_disc = optim.Adam(disc_model.parameters(), lr=0.01)

def train_model(gen, disc):
  for epoch in range(20):
    run_loss_disc = 0
    run_loss_gen = 0
    for data in train:
      low_res = data[0].to(device, non_blocking=True, dtype=torch.float).permute(0, 3, 1, 2)
      high_res = data[1].to(device, non_blocking=True, dtype=torch.float).permute(0, 3, 1, 2)
      #--------Discriminator-----------------
     
      gen_image = gen(low_res)
      gen_image = gen_image.detach()
      disc_gen = disc(gen_image)
      disc_real = disc(high_res)
      p = nn.BCEWithLogitsLoss()
      loss_gen = p(disc_real, torch.ones_like(disc_real))
      
      loss_real = p(disc_gen, torch.zeros_like(disc_gen))
      loss_disc = loss_gen + loss_real
      opt_disc.zero_grad()
      loss_disc.backward()
      
      run_loss_disc+=loss_disc
      #---------Generator--------------------
      cont_loss = vgg_loss(high_res, gen_image)
      adv_loss = 1e-3*p(disc_gen, torch.ones_like(disc_gen))
      gen_loss = cont_loss+(10^-3)*adv_loss
      opt_gen.zero_grad()
      gen_loss.backward()
      opt_disc.step()
      opt_gen.step()
      run_loss_gen+=gen_loss
    print("Run Loss Discriminator: %d", run_loss_disc)
    print("Run Loss Generator: %d", run_loss_gen)

train_model(gen_model, disc_model)

Thanks

This line of code:

run_loss_disc+=loss_disc

looks wrong, as you are accumulating the current discriminator loss together with its entire computation graph, which is then kept alive across iterations.
If you only want to accumulate the loss value for printing purposes, use run_loss_disc += loss_disc.item() or .detach() the tensor first.
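For example, here is a minimal self-contained sketch of that pattern (the toy logits tensor below is just a made-up stand-in for your discriminator output):

import torch
from torch import nn

loss_fn = nn.BCEWithLogitsLoss()
logits = torch.randn(4, 1, requires_grad=True)  # stand-in for disc(gen_image)

running = 0.0
for _ in range(3):
    loss = loss_fn(logits, torch.ones_like(logits))
    loss.backward()     # consumes (frees) this iteration's graph
    logits.grad = None  # stand-in for optimizer.zero_grad()
    # .item() returns a plain Python float, so the running total
    # holds no reference to the computation graph
    running += loss.item()
    # equivalently: running += loss.detach()  (a graph-free tensor)

print("Running loss:", running)

This keeps the bookkeeping variable out of autograd entirely, so nothing downstream can accidentally try to backpropagate through an already-freed graph.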

Hello. I tried this, but I still face the same error.