Hello everyone, I have been trying to train a conditional WGAN-GP network with a U-Net based generator and a simple VGG-style discriminator for image-to-image translation. I am updating my discriminator 12 times for each update of the generator. I know that WGANs are supposed to have stable training, but in my case the training is unstable and I don't know what I am missing. I would really appreciate it if someone here could help me out.
Thanks in advance.
-------------------training loop---------------------------
def train(aD, aG, opt_d, opt_g):
    """WGAN-GP training loop for a conditional (image-to-image) GAN.

    Args:
        aD: discriminator/critic network; scores concatenated (image, condition) pairs.
        aG: generator network; maps concatenated (noise, condition) to an image.
        opt_d: optimizer for the discriminator.
        opt_g: optimizer for the generator.

    NOTE(review): relies on module-level globals (loader_1, loader_2,
    START_ITER, END_ITER, CRITIC_ITERS, GENER_ITERS, c) defined elsewhere
    in the file.
    """
    for iteration in range(START_ITER, END_ITER):
        print("\niteration number: ", iteration)
        start = time.time()
        gen_loss = []
        dis_loss = []
        loader_1_iter = iter(loader_1)
        loader_2_iter = iter(loader_2)
        # ---------------------TRAIN D------------------------
        for i in range(CRITIC_ITERS):
            x = next(loader_1_iter).cuda(c)  # condition images
            y = next(loader_2_iter).cuda(c)  # target (real) images
            aD.zero_grad()
            # Generate fake data conditioned on x; detach so the critic
            # update does not backprop into the generator.
            noise = torch.randn_like(x)  # same shape/device as x, replaces torch.randn([...]).cuda(c)
            noise_x = torch.cat([noise, x], dim=1)  # already on GPU; no extra .cuda(c) needed
            fake_data = aG(noise_x).detach()
            fake_data_d = torch.cat([fake_data, x], dim=1)
            y_real = torch.cat([y, x], dim=1)
            # Critic scores: mean over the batch (WGAN estimates of E[D]).
            disc_real = torch.mean(aD(y_real))
            disc_fake = torch.mean(aD(fake_data_d))
            # Gradient penalty on interpolates between real and fake.
            gradient_penalty = calc_gradient_penalty(aD, y, fake_data, x)
            # WGAN-GP critic loss: maximize D(real) - D(fake) <=> minimize the below.
            disc_cost = disc_fake - disc_real + gradient_penalty
            disc_cost.backward()
            opt_d.step()
            dis_loss.append(disc_cost.item())
        # Restart the loaders so the generator phase sees fresh batches
        # even if the critic phase consumed most of the epoch.
        loader_1_iter = iter(loader_1)
        loader_2_iter = iter(loader_2)
        # ---------------------TRAIN G------------------------
        for i in range(GENER_ITERS):
            x = next(loader_1_iter).cuda(c)
            aG.zero_grad()
            noise = torch.randn_like(x)
            # BUGFIX(review): the original called noise.requires_grad_(True)
            # AFTER noise had already been concatenated, so it had no effect;
            # gradients w.r.t. the noise are not needed for the G update, so
            # the call is removed entirely.
            noise_x = torch.cat([noise, x], dim=1)
            fake_data = aG(noise_x)
            fake_data_d = torch.cat([fake_data, x], dim=1)
            # Generator loss: maximize D(fake) <=> minimize -mean(D(fake)).
            gen_cost = -torch.mean(aD(fake_data_d))
            gen_cost.backward()
            opt_g.step()
            gen_loss.append(gen_cost.item())
        # Overwrites the same preview file each iteration (original behavior).
        torchvision.utils.save_image(fake_data, "fake.jpg", normalize=True)
-----------------this is the gradient penalty function---------------------
(Note: as originally posted, line `alpha = alpha.cuda©` contains a mojibake character — `(c)` was auto-converted to the © symbol — which is a syntax error; it should read `alpha = alpha.cuda(c)`.)
def calc_gradient_penalty(netD, real_data, fake_data, x):
    """WGAN-GP gradient penalty on random interpolates of real/fake images.

    Args:
        netD: critic; scores concatenated (image, condition) tensors.
        real_data: real target images, shape (B, C, H, W).
        fake_data: generator output, same shape as real_data.
        x: condition images concatenated channel-wise before scoring.

    Returns:
        Scalar tensor: LAMBDA * mean((||grad D(interp)||_2 - 1)^2).

    NOTE(review): reads the module-level global LAMBDA.
    """
    # BUGFIX: derive the batch size from the input instead of a global
    # `batch_size` — the global breaks on the last (partial) batch.
    bsz = real_data.size(0)
    # One interpolation coefficient per sample, broadcast over C/H/W.
    # (Equivalent to the original rand -> expand -> view dance, but on the
    # correct device from the start; also fixes the garbled `.cuda©`.)
    alpha = torch.rand(bsz, 1, 1, 1, device=real_data.device)
    fake_data = fake_data.view_as(real_data)
    # Detach both endpoints: gradients must flow only through interpolates.
    interpolates = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
    interpolates.requires_grad_(True)
    disc_interpolates = netD(torch.cat([interpolates, x], dim=1))
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,  # penalty itself is differentiated in D's backward
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty