RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Any Suggestions Appreciated

I have working code based on img2img-turbo/src/train_cyclegan_turbo.py at main · GaParmar/img2img-turbo · GitHub. I added a few lines of code (marked # ADDED BY ME / # MODIFIED BY ME below), and after that I get the error. I know why I get the error; however, I couldn't figure out how to solve it. If I use retain_graph=True, I run out of memory, as expected. Is there any workaround? I appreciate any suggestions. The code that produces the error is below.

for epoch in range(first_epoch, args.max_train_epochs):
    for step, batch in enumerate(train_dataloader):
        l_acc = [unet, net_disc_a, net_disc_b, vae_enc, vae_dec, init_pred_net]
        with accelerator.accumulate(*l_acc):
            burst_img = batch["pixel_values_src"].to(dtype=weight_dtype)
            print("burst_img shape", burst_img.shape, flush=True)
            # ADDED BY ME
            img_a = init_pred_net(burst_img) # initial prediction
            # ADDED BY ME
            del burst_img
            print("img_pred shape after init pred", img_a.shape, flush=True)

            img_b = batch["pixel_values_tgt"].to(dtype=weight_dtype)
            print("img_b shape", img_b.shape, flush=True)
            

            bsz = img_a.shape[0]
            fixed_a2b_emb = fixed_a2b_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)
            fixed_b2a_emb = fixed_b2a_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)
            timesteps = torch.tensor([noise_scheduler_1step.config.num_train_timesteps - 1] * bsz, device=img_a.device).long()

            """
            Cycle Objective
            """
            # A -> fake B -> rec A
            cyc_fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
            cyc_rec_a = CycleGAN_Turbo.forward_with_networks(cyc_fake_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
            loss_cycle_a = crit_cycle(cyc_rec_a, img_a) * args.lambda_cycle
            loss_cycle_a += net_lpips(cyc_rec_a, img_a).mean() * args.lambda_cycle_lpips
            # B -> fake A -> rec B
            cyc_fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
            cyc_rec_b = CycleGAN_Turbo.forward_with_networks(cyc_fake_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
            loss_cycle_b = crit_cycle(cyc_rec_b, img_b) * args.lambda_cycle
            loss_cycle_b += net_lpips(cyc_rec_b, img_b).mean() * args.lambda_cycle_lpips
            # ADDED BY ME
            loss_init_pred = crit_init_pred(img_a, img_b)
            # MODIFIED BY ME
            accelerator.backward(loss_cycle_a + loss_cycle_b + loss_init_pred, retain_graph=False)
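            # NOTE: as far as I can tell, this backward() call frees the graph running
            # through init_pred_net, but its output img_a is reused further down
            # (in the GAN, identity, and discriminator passes), so the next backward()
            # hits the "backward through the graph a second time" error.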
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)

            optimizer_gen.step()
            lr_scheduler_gen.step()
            optimizer_gen.zero_grad()

            """
            Generator Objective (GAN) for task a->b and b->a (fake inputs)
            """
            fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
            fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
            loss_gan_a = net_disc_a(fake_b, for_G=True).mean() * args.lambda_gan
            loss_gan_b = net_disc_b(fake_a, for_G=True).mean() * args.lambda_gan
            accelerator.backward(loss_gan_a + loss_gan_b, retain_graph=False)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
            optimizer_gen.step()
            lr_scheduler_gen.step()
            optimizer_gen.zero_grad()
            optimizer_disc.zero_grad()

            """
            Identity Objective
            """
            idt_a = CycleGAN_Turbo.forward_with_networks(img_b, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
            loss_idt_a = crit_idt(idt_a, img_b) * args.lambda_idt
            loss_idt_a += net_lpips(idt_a, img_b).mean() * args.lambda_idt_lpips
            idt_b = CycleGAN_Turbo.forward_with_networks(img_a, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
            loss_idt_b = crit_idt(idt_b, img_a) * args.lambda_idt
            loss_idt_b += net_lpips(idt_b, img_a).mean() * args.lambda_idt_lpips
            loss_g_idt = loss_idt_a + loss_idt_b
            accelerator.backward(loss_g_idt, retain_graph=False)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
            optimizer_gen.step()
            lr_scheduler_gen.step()
            optimizer_gen.zero_grad()

            """
            Discriminator for task a->b and b->a (fake inputs)
            """
            loss_D_A_fake = net_disc_a(fake_b.detach(), for_real=False).mean() * args.lambda_gan
            loss_D_B_fake = net_disc_b(fake_a.detach(), for_real=False).mean() * args.lambda_gan
            loss_D_fake = (loss_D_A_fake + loss_D_B_fake) * 0.5
            accelerator.backward(loss_D_fake, retain_graph=False)
            if accelerator.sync_gradients:
                params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer_disc.step()
            lr_scheduler_disc.step()
            optimizer_disc.zero_grad()

            """
            Discriminator for task a->b and b->a (real inputs)
            """
            loss_D_A_real = net_disc_a(img_b, for_real=True).mean() * args.lambda_gan
            loss_D_B_real = net_disc_b(img_a, for_real=True).mean() * args.lambda_gan
            loss_D_real = (loss_D_A_real + loss_D_B_real) * 0.5
            accelerator.backward(loss_D_real, retain_graph=False)
            if accelerator.sync_gradients:
                params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer_disc.step()
            lr_scheduler_disc.step()
            optimizer_disc.zero_grad()
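
For clarity, here is a minimal sketch of the pattern that I believe triggers the error. The module and variable names are made up for illustration (they are not from the training script): an intermediate activation is produced once and then reused in a second backward() call without retain_graph.

import torch
import torch.nn as nn

# Hypothetical stand-ins for init_pred_net and the downstream generator.
init_pred_net = nn.Linear(8, 8)
gen = nn.Linear(8, 8)

x = torch.randn(4, 8)
img_a = init_pred_net(x)          # produced once; graph goes through init_pred_net

loss1 = gen(img_a).pow(2).mean()
loss1.backward()                  # frees the graph, including the part through init_pred_net

loss2 = gen(img_a).pow(2).mean()  # img_a still points into the freed graph
loss2.backward()                  # RuntimeError: Trying to backward through the graph a second time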