I am trying to train SRGAN from scratch. I have read solutions for this type of problem, but it would be great if someone could help me debug my code.
gen_model = Generator().to(device, non_blocking=True)
disc_model = Discriminator().to(device, non_blocking=True)
opt_gen = optim.Adam(gen_model.parameters(), lr=0.01)
opt_disc = optim.Adam(disc_model.parameters(), lr=0.01)
from torch.nn.modules.loss import BCELoss
def train_model(gen, disc):
for epoch in range(20):
run_loss_disc = 0
run_loss_gen = 0
for data in train:
low_res, high_res = data[0].to(device, non_blocking=True, dtype=torch.float).permute(0, 3, 1, 2),data[1].to(device, non_blocking=True, dtype=torch.float).permute(0, 3, 1, 2)
#--------Discriminator-----------------
gen_image = gen(low_res)
gen_image = gen_image.detach()
disc_gen = disc(gen_image)
disc_real = disc(high_res)
p=nn.BCEWithLogitsLoss()
loss_gen = p(disc_real, torch.ones_like(disc_real))
loss_real = p(disc_gen, torch.zeros_like(disc_gen))
loss_disc = loss_gen + loss_real
opt_disc.zero_grad()
loss_disc.backward()
run_loss_disc+=loss_disc
#---------Generator--------------------
cont_loss = vgg_loss(high_res, gen_image)
adv_loss = 1e-3*p(disc_gen, torch.ones_like(disc_gen))
gen_loss = cont_loss+(10^-3)*adv_loss
opt_gen.zero_grad()
gen_loss.backward()
opt_disc.step()
opt_gen.step()
run_loss_gen+=gen_loss
print("Run Loss Discriminator: %d", run_loss_disc)
print("Run Loss Generator: %d", run_loss_gen)
train_model(gen_model, disc_model)
Thanks