Hi,
I am trying to implement SRGAN (super resolution).
I am training my model on a single batch of 10 images, just to check that it can overfit them correctly.
Here are the results after 200 epochs on these 10 images:
I strongly suspect my training loop, and my poor understanding of how gradients flow in PyTorch.
Here is the code:
# SRGAN training loop: alternate one discriminator step and one generator
# step per batch, log losses every 10 batches, and show a visual preview of
# the generator on the first training batch once per epoch.
#
# NOTE(review): the original paste lost its indentation; the nesting below
# (preview + "Epoch" print at epoch level) is a reconstruction — confirm.
#
# Relies on names defined elsewhere: epochs, train_data, device, g_net, d_net,
# g_optim, d_optim, bce_criterion, content_loss, adv_coeff, display_grid.

# Anomaly detection makes every backward pass much slower; switch it on only
# while hunting NaN / in-place autograd errors.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

for epoch in range(epochs):
    for i, batch in enumerate(train_data):
        lr = batch["lr"].to(device)
        hr = batch["hr"].to(device)
        bs = lr.shape[0]

        # Separate, never-mutated target tensors. The original reused one
        # tensor and changed it with fill_() between losses, which is fragile
        # once autograd graphs may still reference it.
        real_label = torch.ones(bs, device=device, dtype=torch.float32)
        fake_label = torch.zeros(bs, device=device, dtype=torch.float32)

        # ---- Discriminator step: push D(hr) -> 1 and D(G(lr)) -> 0 ----
        d_net.zero_grad()
        real_pred = d_net(hr).view(-1)
        real_loss = bce_criterion(real_pred, real_label)
        sr = g_net(lr)
        # detach(): the discriminator update must not push gradients into g_net.
        fake_pred = d_net(sr.detach()).view(-1)
        fake_loss = bce_criterion(fake_pred, fake_label)
        # One backward on the sum gives the same gradients as two backward calls.
        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optim.step()

        with torch.no_grad():
            # Mean discriminator outputs. These are probabilities only if
            # d_net ends in a sigmoid (i.e. bce_criterion is BCELoss).
            # TODO confirm.
            p_real = real_pred.mean()
            p_fake = fake_pred.mean()

        # ---- Generator step: fool D (D(G(lr)) -> 1) and match hr content ----
        g_net.zero_grad()
        # Non-detached sr, so the adversarial gradient flows back through
        # d_net into g_net. The gradients this also leaves on d_net's
        # parameters are never applied; d_net.zero_grad() clears them on the
        # next iteration.
        sr_preds = d_net(sr).view(-1)
        adv_loss = adv_coeff * bce_criterion(sr_preds, real_label)
        content_criterion = content_loss(hr, sr)
        # A single backward on the total loss gives the same gradients as
        # adv_loss.backward(retain_graph=True) + content_criterion.backward(),
        # without keeping the generator graph alive for a second pass.
        g_loss = content_criterion + adv_loss
        g_loss.backward()
        g_optim.step()

        with torch.no_grad():
            p_fake_trick = sr_preds.mean()

        # ---- Monitoring ----
        if i % 10 == 0:
            print("d_loss = ", d_loss.item())
            print("g_loss = ", g_loss.item())
            print("D(Real) = ", p_real.item())
            print("D(Fake) = ", p_fake.item())
            print("D(Fake)2 = ", p_fake_trick.item())
            print("Content loss = ", content_criterion.item())

    # ---- Visual preview on the first training batch ----
    # eval() stops BatchNorm/Dropout layers (if g_net has any) from updating
    # running statistics or dropping units during this inference-only pass.
    g_net.eval()
    with torch.no_grad():
        test_batch = next(iter(train_data))
        lr_test = test_batch["lr"].to(device)
        hr_test = test_batch["hr"].to(device)
        print("LR")
        display_grid(lr_test.cpu())
        sr_test = g_net(lr_test)
        print("Generated")
        display_grid(sr_test.cpu())
        print("Real")
        display_grid(hr_test.cpu())
    g_net.train()
    print("Epoch : ", epoch)
Thank you very much for your help!