GAN with two loss functions: "Trying to backward through the graph a second time" error

I made a 3D variant of the (originally 2D) Pix2pixHD model. The error above is raised when the discriminator loss is backpropagated right after the generator loss. I'm trying to use detach(), but I can't decide where to put it.

Any help would be appreciated, thank you.
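
For reference, here is my understanding of where detach() normally goes, as a minimal self-contained toy (single Conv3d layers standing in for G and D, not my real pix2pixHD_3D modules):

import torch
import torch.nn as nn

G = nn.Conv3d(1, 1, 3, padding=1)  # toy generator
D = nn.Conv3d(2, 1, 3, padding=1)  # toy conditional discriminator
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4)
opt_D = torch.optim.Adam(D.parameters(), lr=1e-4)
mse = nn.MSELoss()

cond = torch.randn(1, 1, 8, 8, 8)
real = torch.randn(1, 1, 8, 8, 8)
fake = G(cond)

# generator step: NO detach, so the gradient flows through D back into G
opt_G.zero_grad()
out = D(torch.cat([fake, cond], dim=1))
loss_G = mse(out, torch.ones_like(out))
loss_G.backward()
opt_G.step()

# discriminator step: detach the fake, so this graph is independent of G's
opt_D.zero_grad()
out_f = D(torch.cat([fake.detach(), cond], dim=1))
out_r = D(torch.cat([real, cond], dim=1))
loss_D = 0.5 * (mse(out_f, torch.zeros_like(out_f)) + mse(out_r, torch.ones_like(out_r)))
loss_D.backward()
opt_D.step()

My actual code is below: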


generator = pix2pixHD_3D.Generator_HD()

discriminator_m = pix2pixHD_3D.PatchDiscriminator_m()
discriminator_x2 = pix2pixHD_3D.PatchDiscriminator_x2()
discriminator_x4 = pix2pixHD_3D.PatchDiscriminator_x4()

# DCGAN-style initialization: N(0, 0.02) for conv weights, N(1, 0.02) for BN scales
def weights_init(module):
    if isinstance(module, nn.Conv3d):
        module.weight.detach().normal_(0.0, 0.02)
    elif isinstance(module, nn.BatchNorm3d):
        module.weight.detach().normal_(1.0, 0.02)
        module.bias.detach().fill_(0.0)

generator.apply(weights_init).cuda()
discriminator_m.apply(weights_init).cuda()
discriminator_x2.apply(weights_init).cuda()
discriminator_x4.apply(weights_init).cuda()

criterion_GAN = torch.nn.MSELoss()       # LSGAN loss, chance level ~0.25 (torch.nn.BCELoss would be ~0.693)
criterion_pixelwise = torch.nn.L1Loss()  # alternative: torch.nn.SmoothL1Loss()

criterion_GAN.cuda()
criterion_pixelwise.cuda()

lr = 0.00005

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator_m.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_x2 = torch.optim.Adam(discriminator_x2.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_x4 = torch.optim.Adam(discriminator_x4.parameters(), lr=lr, betas=(0.5, 0.999))

def get_grid(input, is_real=True):
    # Target tensor with the same shape as a discriminator output:
    # all ones for "real", all zeros for "fake".
    if is_real:
        return torch.FloatTensor(input.shape).fill_(1.0)
    return torch.FloatTensor(input.shape).fill_(0.0)

import time

lambda_pixel = 100

start_time = time.time()  # main version 01, 3D

loss_hist = {'gen': [],
             'dis': [],
             'L1': [],
             'test_PSNR': [],
             'test_mse': [],
             'test_ssim': []}

PATH = 'save_model/TEMP/'
patch_size = 16  # 16 or 32

n_epochs = 1

# nn.Upsample is used here simply to resize (downsample) to the coarser scales
down_2x = nn.Upsample(size=(64, 64, 32))
down_4x = nn.Upsample(size=(32, 32, 16))


for epoch in range(n_epochs):
    for ii, batch in enumerate(train_dataloader):

        condition_im = batch[0].cuda()
        real_im = batch[1].cuda()


        optimizer_G.zero_grad()

        # generate a fake image
        fake_im = generator(condition_im, down_2x(condition_im))


        f_out_dis = discriminator_m(fake_im.detach(), condition_im)
        f_out_dis2 = discriminator_x2(down_2x(fake_im).detach(), down_2x(condition_im))  # 2x down
        f_out_dis3 = discriminator_x4(down_4x(fake_im).detach(), down_4x(condition_im))  # 4x down

        loss_G = 0
        f_loss_D = 0
        r_loss_D = 0
        for i in range(len(f_out_dis)):
            real_grid = get_grid(f_out_dis[i], is_real=True).cuda()
            loss_G += criterion_GAN(f_out_dis[i], real_grid)

        for i in range(len(f_out_dis2)):
            real_grid = get_grid(f_out_dis2[i], is_real=True).cuda()
            loss_G += criterion_GAN(f_out_dis2[i], real_grid)

        for i in range(len(f_out_dis3)):
            real_grid = get_grid(f_out_dis3[i], is_real=True).cuda()
            loss_G += criterion_GAN(f_out_dis3[i], real_grid)

        loss_G_FM = criterion_pixelwise(fake_im, real_im.detach())  # pixelwise L1 term (not feature matching, despite the name)

        all_d_len = len(f_out_dis) + len(f_out_dis2) + len(f_out_dis3)

        loss_G_f = loss_G * (1.0 / all_d_len) + loss_G_FM * lambda_pixel  # mean adversarial term + weighted L1

        loss_G_f.backward()
        optimizer_G.step()

        optimizer_D.zero_grad()
        optimizer_D_x2.zero_grad()
        optimizer_D_x4.zero_grad()


        r_out_dis = discriminator_m(real_im, condition_im)  # discriminator outputs for the real pair
        r_out_dis2 = discriminator_x2(down_2x(real_im), down_2x(condition_im))  # 2x down
        r_out_dis3 = discriminator_x4(down_4x(real_im), down_4x(condition_im))  # 4x down

        for i in range(len(f_out_dis)):
            real_grid = get_grid(f_out_dis[i], is_real=True).cuda()
            fake_grid = get_grid(f_out_dis[i], is_real=False).cuda()
            f_loss_D += criterion_GAN(f_out_dis[i], fake_grid)
            r_loss_D += criterion_GAN(r_out_dis[i], real_grid)

        for i in range(len(f_out_dis2)):
            real_grid = get_grid(f_out_dis2[i], is_real=True).cuda()
            fake_grid = get_grid(f_out_dis2[i], is_real=False).cuda()
            f_loss_D += criterion_GAN(f_out_dis2[i], fake_grid)
            r_loss_D += criterion_GAN(r_out_dis2[i], real_grid)

        for i in range(len(f_out_dis3)):
            real_grid = get_grid(f_out_dis3[i], is_real=True).cuda()
            fake_grid = get_grid(f_out_dis3[i], is_real=False).cuda()
            f_loss_D += criterion_GAN(f_out_dis3[i], fake_grid)
            r_loss_D += criterion_GAN(r_out_dis3[i], real_grid)


        loss_D = (f_loss_D + r_loss_D)/(2*all_d_len)


        loss_D.backward()  # <-- error occurs here
        optimizer_D.step()
        optimizer_D_x2.step()
        optimizer_D_x4.step()


What about passing retain_graph=True to the second backward?
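
Something like this, I mean (just a sketch against the variable names in your post):

loss_G_f.backward()
optimizer_G.step()
# ... build the discriminator losses ...
loss_D.backward(retain_graph=True)
optimizer_D.step()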

Thank you for the answer.

As you suggested, I still get the same error if I pass retain_graph=True to the second backward.

Oh sorry, I actually meant the first backward, whoops! I don't know why I said the second one.
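
So, concretely (same variable names as the post):

loss_G_f.backward(retain_graph=True)  # keep the saved graph so loss_D.backward() can reuse it
optimizer_G.step()
# ...
loss_D.backward()
optimizer_D.step()

That said, retain_graph=True would only suppress the error, not fix the structure. loss_D is built from f_out_dis / f_out_dis2 / f_out_dis3, which were computed during the generator step, so loss_G_f.backward() frees their graph and loss_D.backward() then tries to traverse it a second time. Note also that because those outputs come from fake_im.detach(), the adversarial part of loss_G cannot reach the generator at all; only the L1 term trains it. The usual arrangement (just a sketch of the common pix2pix-style pattern, shown for discriminator_m only; the x2/x4 scales follow the same shape) is:

# generator step: NO detach here, so gradients flow through D back into G
f_out_dis = discriminator_m(fake_im, condition_im)
# ... build loss_G_f from f_out_dis (targets = ones) plus the L1 term ...
loss_G_f.backward()
optimizer_G.step()

# discriminator step: recompute on the detached fake -> a fresh, independent graph
optimizer_D.zero_grad()
f_out_dis = discriminator_m(fake_im.detach(), condition_im)
r_out_dis = discriminator_m(real_im, condition_im)
# ... build loss_D (fake targets = zeros, real targets = ones) ...
loss_D.backward()  # no retain_graph needed
optimizer_D.step()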