I made a 2d Pix2pixHD model variant 3d model.
Discriminator back-to-back errors occurred.
I’m trying to use detach, but I can’t decide where to use it.
help.
thank you
generator = pix2pixHD_3D.Generator_HD()
discriminator_m = pix2pixHD_3D.PatchDiscriminator_m()
discriminator_x2 = pix2pixHD_3D.PatchDiscriminator_x2()
discriminator_x4 = pix2pixHD_3D.PatchDiscriminator_x4()
def weights_init(module):
if isinstance(module, nn.Conv3d):
module.weight.detach().normal_(0.0, 0.02)
elif isinstance(module, nn.BatchNorm3d):
module.weight.detach().normal_(1.0, 0.02)
module.bias.detach().fill_(0.0)
generator.apply(weights_init).cuda()
discriminator_m.apply(weights_init).cuda()
discriminator_x2.apply(weights_init).cuda()
discriminator_x4.apply(weights_init).cuda()
criterion_GAN = torch.nn.MSELoss()# 0.25 # torch.nn.BCELoss() 0.693
criterion_pixelwise = torch.nn.L1Loss() # torch.nn.SmoothL1Loss()
criterion_GAN.cuda()
criterion_pixelwise.cuda()
lr = 0.00005
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator_m.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_x2 = torch.optim.Adam(discriminator_x2.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_x4 = torch.optim.Adam(discriminator_x4.parameters(), lr=lr, betas=(0.5, 0.999))
def get_grid(input, is_real=True):
if is_real:
grid = torch.FloatTensor(input.shape).fill_(1.0)
elif not is_real:
grid = torch.FloatTensor(input.shape).fill_(0.0)
return grid
import time
lambda_pixel = 100
start_time = time.time()# main version 01 3D
loss_hist = {'gen': [],
'dis': [],
'L1': [],
'test_PSNR':[],
'test_mse' : [],
'test_ssim': []}
#
PATH = 'save_model/TEMP/'
patch_size = 16 #16 32
n_epochs = 1 #
dwon_2x = nn.Upsample(size=(64,64,32))
dwon_4x = nn.Upsample(size=(32,32,16))
for epoch in range(n_epochs):
for ii, batch in enumerate(train_dataloader):
rrer = qwerqw
condition_im = batch[0].cuda()
real_im = batch[1].cuda()
optimizer_G.zero_grad()
# 이미지 생성
fake_im = generator(condition_im, dwon_2x(condition_im)) #
f_out_dis = discriminator_m(fake_im.detach(), condition_im)
#r_out_dis = discriminator_m(real_im, condition_im)
f_out_dis2 = discriminator_x2(dwon_2x(fake_im).detach(), dwon_2x(condition_im)) # 2x down
#r_out_dis2 = discriminator_x2(dwon_2x(real_im), dwon_2x(condition_im)) # 2x down
f_out_dis3 = discriminator_x4(dwon_4x(fake_im).detach(), dwon_4x(condition_im)) # 4x down
#r_out_dis3 = discriminator_x4(dwon_4x(real_im), dwon_4x(condition_im)) # 4x down
loss_G = 0
f_loss_D = 0
r_loss_D = 0
for i in range(len(f_out_dis)):
#rrer = qwerqw
real_grid = get_grid(f_out_dis[i], is_real=True).cuda()
fake_grid = get_grid(f_out_dis[i], is_real=False).cuda()
#f_loss_D += criterion_GAN(f_out_dis[i], fake_grid)
#r_loss_D += criterion_GAN(r_out_dis[i], real_grid)
loss_G += criterion_GAN(f_out_dis[i], real_grid)
for i in range(len(f_out_dis2)):
# rrer = qwerqw
real_grid = get_grid(f_out_dis2[i], is_real=True).cuda()
fake_grid = get_grid(f_out_dis2[i], is_real=False).cuda()
#f_loss_D += criterion_GAN(f_out_dis2[i], fake_grid)
#r_loss_D += criterion_GAN(r_out_dis2[i], real_grid)
loss_G += criterion_GAN(f_out_dis2[i], real_grid)
for i in range(len(f_out_dis3)):
# rrer = qwerqw
real_grid = get_grid(f_out_dis3[i], is_real=True).cuda()
fake_grid = get_grid(f_out_dis3[i], is_real=False).cuda()
#f_loss_D += criterion_GAN(f_out_dis3[i], fake_grid)
#r_loss_D += criterion_GAN(r_out_dis3[i], real_grid)
loss_G += criterion_GAN(f_out_dis3[i], real_grid)
loss_G_FM = criterion_pixelwise(fake_im, real_im.detach())
all_d_len = len(f_out_dis)+len(f_out_dis2)+len(f_out_dis3)
loss_G_f = loss_G * (1.0 / all_d_len) + loss_G_FM * lambda_pixel
loss_G_f.backward()
optimizer_G.step()
optimizer_D.zero_grad()
optimizer_D_x2.zero_grad()
optimizer_D_x4.zero_grad()
r_out_dis = discriminator_m(real_im, condition_im) # discriminator : 가짜 이미지 평가 0 ~ 1
r_out_dis2 = discriminator_x2(dwon_2x(real_im), dwon_2x(condition_im)) # 2x down
r_out_dis3 = discriminator_x4(dwon_4x(real_im), dwon_4x(condition_im)) # 4x down
for i in range(len(f_out_dis)):
#rrer = qwerqw
real_grid = get_grid(f_out_dis[i], is_real=True).cuda()
fake_grid = get_grid(f_out_dis[i], is_real=False).cuda()
f_loss_D += criterion_GAN(f_out_dis[i], fake_grid)
r_loss_D += criterion_GAN(r_out_dis[i], real_grid)
for i in range(len(f_out_dis2)):
# rrer = qwerqw
real_grid = get_grid(f_out_dis2[i], is_real=True).cuda()
fake_grid = get_grid(f_out_dis2[i], is_real=False).cuda()
f_loss_D += criterion_GAN(f_out_dis2[i], fake_grid)
r_loss_D += criterion_GAN(r_out_dis2[i], real_grid)
for i in range(len(f_out_dis3)):
# rrer = qwerqw
real_grid = get_grid(f_out_dis3[i], is_real=True).cuda()
fake_grid = get_grid(f_out_dis3[i], is_real=False).cuda()
f_loss_D += criterion_GAN(f_out_dis3[i], fake_grid)
r_loss_D += criterion_GAN(r_out_dis3[i], real_grid)
loss_D = (f_loss_D + r_loss_D)/(2*all_d_len)
loss_D.backward() @@ error point
optimizer_D.step()
optimizer_D_x2.step()
optimizer_D_x4.step()