I have tried many approaches, but nothing seems to work. As a beginner, I don't understand how to change the code to make it work. Did I modify it incorrectly?
# Anomaly detection makes autograd report which forward operation produced the
# tensor that broke backward() (e.g. one later modified in place). It slows
# training considerably, so switch it off once the error has been located.
torch.autograd.set_detect_anomaly(True)

# ---------------------------------------------------------------------------
# 1st path: update the discriminator D_xvs.
#
# NOTE(review): x_bar_bar_loss_s, x_loss_s, gp_loss and v_loss_x are computed
# before this chunk. x_bar_bar_loss_s should presumably come from
# D_xvs(x_bar_bar.detach()). If it was computed on the NON-detached generator
# output, d_x_loss.backward() also runs through (and frees) the generator's
# graph, and g_loss.backward() further down fails. Confirm against the code
# above.
# ---------------------------------------------------------------------------
d_x_loss = x_bar_bar_loss_s - x_loss_s + 10. * gp_loss + v_loss_x
D_xvs_solver.zero_grad()
d_x_loss.backward()
D_xvs_solver.step()

# ---------------------------------------------------------------------------
# 2nd path: update the generator(s). G_xvz is always stepped; G_vzx is
# stepped only when reconstruct_fake is False.
#
# D_xvs is deliberately re-run on x_bar_bar here. D_xvs_solver.step() has just
# changed D's weights in place, so D outputs computed before the step can no
# longer be back-propagated through.
# ---------------------------------------------------------------------------
x_bar_bar_loss_v, x_bar_bar_loss_s = D_xvs(x_bar_bar)
x_bar_bar_loss_s = x_bar_bar_loss_s.mean()

if reconstruct_fake:
    # x_bar is only the reconstruction *target*, so detach it. Otherwise
    # g_loss.backward() would also walk back into the graph that produced
    # x_bar. That graph may already have been freed by an earlier backward()
    # or modified in place by an earlier optimizer step, which raises an
    # autograd error.
    x_l1_loss = L1_loss(x_bar_bar, x_bar.detach())
    # ACGAN loss of x_bar_bar (label vv1)
    v_loss_x_bar_bar = crossEntropyLoss(x_bar_bar_loss_v, vv1)
else:
    x_l1_loss = L1_loss(x_bar_bar, x2)
    # ACGAN loss of x_bar_bar (label vv2)
    v_loss_x_bar_bar = crossEntropyLoss(x_bar_bar_loss_v, vv2)

# Classification loss of v_bar against label vv1.
v_loss_x = crossEntropyLoss(v_bar, vv1)

# Terms shared by both branches (same left-to-right order as before).
g_loss = -x_bar_bar_loss_s + 4 * x_l1_loss + v_loss_x_bar_bar + 0.01 * v_loss_x

if not reconstruct_fake:
    # ID loss: compares fake_fea with real_fea via cri_rec, weighted by
    # ID_weight. The all-ones target (one entry per sample) is assumed to be
    # what cri_rec expects -- TODO confirm cri_rec's signature.
    l_g_rec = ID_weight * cri_rec(
        fake_fea, real_fea,
        torch.ones(real_fea.shape[0], device=device),
    )
    symmetry_128_loss = args.symmetry_loss_weight * Sym_loss(x_bar_bar)
    g_loss = g_loss + l_g_rec + symmetry_128_loss

G_vzx_solver.zero_grad()
G_xvz_solver.zero_grad()
g_loss.backward()
if not reconstruct_fake:
    G_vzx_solver.step()
G_xvz_solver.step()