I have implemented a pix2pix GAN from scratch, and I have run into a problem with the generator loss: it keeps rising (see the visualisation of the training and validation losses). As the backbone I use a UNet I implemented myself (it is too big to post here, but I can add it if that would be helpful).
Why might this be happening?
This is my training section:
def set_grad(model, flag):
    """Toggle gradient tracking for every parameter of *model*.

    Used to freeze the discriminator while the generator is updated,
    so the generator's backward pass does not deposit gradients in it.
    """
    for param in model.parameters():
        param.requires_grad_(flag)
for epoch in range(epochs):
    avg_gen_train_loss = 0.0
    avg_disc_train_loss = 0.0
    disc_model.train()
    gen_model.train()
    # Loss modules are stateless — build them once, not once per batch.
    bce = torch.nn.BCEWithLogitsLoss()
    l1 = torch.nn.L1Loss()
    for X_batch, Y_batch in data_tr:
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)  # data to device

        # Generator forward pass, shared by both updates below.
        Y_batch_gen = gen_model(X_batch)

        # ---------------- discriminator update ----------------
        set_grad(disc_model, True)
        disc_optim.zero_grad()
        # Conditional input: source image concatenated with the candidate
        # target along the channel dimension (pix2pix PatchGAN convention).
        X_Y_gen = torch.cat((X_batch, Y_batch_gen), 1)
        # .detach() so the discriminator loss does not backprop into the generator.
        d_gen = disc_model(X_Y_gen.detach())
        # zeros_like/ones_like replace Tensor([..]).expand_as(..).to(device):
        # same values, correct shape/device automatically.
        loss_fake = bce(d_gen, torch.zeros_like(d_gen))
        X_Y = torch.cat((X_batch, Y_batch), 1)
        d_real = disc_model(X_Y)
        loss_real = bce(d_real, torch.ones_like(d_real))
        disc_loss = (loss_fake + loss_real) * 0.5
        # BUG FIX: accumulate a Python float via .item(). Summing the tensor
        # itself keeps every batch's autograd graph alive for the whole epoch
        # (a growing memory leak) and leaves the logged "loss" attached to
        # the graph.
        avg_disc_train_loss += disc_loss.item()
        disc_loss.backward()
        disc_optim.step()
        # ------------------------------------------------------

        # ------------------ generator update ------------------
        # Freeze the discriminator so this backward pass does not leave
        # stale gradients in its parameters.
        set_grad(disc_model, False)
        gen_optim.zero_grad()
        # NOT detached here: gradients must flow through Y_batch_gen
        # into the generator.
        d_gen = disc_model(X_Y_gen)
        adv_loss = bce(d_gen, torch.ones_like(d_gen))
        loss_l1 = l1(Y_batch_gen, Y_batch)
        # Total generator objective: adversarial term + lambda_-weighted L1
        # reconstruction term (pix2pix). NOTE(review): if the visualised
        # generator loss includes the lambda_*L1 term with a large lambda_,
        # the adversarial component can dominate the curve's trend — worth
        # plotting adv_loss and loss_l1 separately to diagnose the rise.
        gen_loss = adv_loss + lambda_ * loss_l1
        avg_gen_train_loss += gen_loss.item()  # BUG FIX: .item(), as above
        gen_loss.backward()
        gen_optim.step()
        # ------------------------------------------------------