I am quite new to PyTorch and am currently working with CycleGAN (the PyTorch implementation) as part of my project. I understand most of the CycleGAN implementation.
I read the paper ‘CycleGAN with better Cycles’ and I am trying to apply the modifications mentioned in it. One of them is cycle consistency weight decay, which I don’t know how to apply.
    optimizer_G.zero_grad()

    # Identity loss
    loss_id_A = criterion_identity(G_BA(real_A), real_A)
    loss_id_B = criterion_identity(G_AB(real_B), real_B)
    loss_identity = (loss_id_A + loss_id_B) / 2

    # GAN loss
    fake_B = G_AB(real_A)
    loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
    fake_A = G_BA(real_B)
    loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
    loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

    # Cycle consistency loss
    recov_A = G_BA(fake_B)
    loss_cycle_A = criterion_cycle(recov_A, real_A)
    recov_B = G_AB(fake_A)
    loss_cycle_B = criterion_cycle(recov_B, real_B)
    loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

    # Total loss (lambda_cyc is 10, lambda_id is 0.5 * lambda_cyc)
    loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity

    loss_G.backward()
    optimizer_G.step()
My question is: how can I gradually decay the weight of the cycle consistency loss (lambda_cyc in the code above)?
Any help implementing this modification would be appreciated.
This is from the paper:
Cycle consistency loss helps to stabilize training a lot in early stages but becomes an obstacle towards realistic images in later stages. We propose to gradually decay the weight of cycle consistency loss λ as training progress. However, we should still make sure that λ is not decayed to 0 so that generators won’t become unconstrained and go completely wild.
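To make the question concrete, here is a minimal sketch of what I imagine the decay could look like. It is not taken from the paper: the linear schedule, the floor value of 1.0, and the names decayed_lambda_cyc, lambda_min and decay_start_epoch are all my own assumptions. The idea is to recompute lambda_cyc once per epoch and clamp it at a positive minimum so it never reaches 0.

```python
def decayed_lambda_cyc(epoch, n_epochs, lambda_start=10.0, lambda_min=1.0,
                       decay_start_epoch=0):
    """Hypothetical helper: linearly decay the cycle weight, never below lambda_min.

    The linear shape and the default floor of 1.0 are assumptions,
    not values given in the paper.
    """
    if epoch <= decay_start_epoch:
        # Keep the full weight during the early, stabilizing phase.
        return lambda_start
    # Number of epochs over which the decay is spread (at least 1 to avoid /0).
    span = max(1, n_epochs - 1 - decay_start_epoch)
    # Fraction of the decay completed, capped at 1 so the floor is respected.
    progress = min((epoch - decay_start_epoch) / span, 1.0)
    return lambda_start + (lambda_min - lambda_start) * progress


n_epochs = 200
for epoch in range(n_epochs):
    lambda_cyc = decayed_lambda_cyc(epoch, n_epochs)
    # Assumption: lambda_id keeps following lambda_cyc, as in my current setup.
    lambda_id = 0.5 * lambda_cyc
    # ... training batches for this epoch would go here, using
    # loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity
```

Whether lambda_id should decay together with lambda_cyc or stay fixed is something I am unsure about; the sketch simply keeps the 0.5 ratio from my current code.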
Thanks in advance.