Hi,
I am trying to implement SRGAN, I have looked at many implementation on Github and none of them used the retain_graph = True option.
I have the following error :
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time.
Here is my training Loop :
# training loop
torch.autograd.set_detect_anomaly(True)
dnet.train()
gnet.train()
discriminator_accuracy = 0
generator_accuracy = 0
d_loss = 0
g_loss = 0
for epoch in range(epochs):
for i, sample in enumerate(data_loader):
hr, lr = sample["hr"], sample["lr"]
hr = hr.to(device)
lr = lr.to(device)
# TODO : add noise
generated_sr = gnet(lr)
true_labels = torch.full(size=(batch_size,1), fill_value=1, device=device, dtype=torch.float32)
################# TRAIN DISCRIMINATOR ###############
d_optim.zero_grad()
for p in dnet.parameters():
p.requires_grad = True
true_preds = dnet(hr)
true_loss = bce_criterion(true_preds, true_labels)
true_loss.backward()
fake_labels = true_labels.fill_(0)
fake_preds = dnet(generated_sr.detach())
discriminator_accuracy += compute_accuracy(fake_preds, fake_labels)
fake_loss = bce_criterion(fake_preds, fake_labels)
fake_loss.backward()
# update parameters
d_loss += fake_loss + true_loss
# zero grad
d_optim.step()
################# TRAIN GENERATOR ####################
true_labels = fake_labels.fill_(1)
for p in dnet.parameters():
p.requires_grad = False
# zero grad
# tricking the discriminator
dg_out = dnet(generated_sr)
# compute total loss regarding the generator
g_bce_loss = 1e-3 * bce_criterion(dg_out,true_labels)
# content loss
content_loss = 1.0 * content_criterion(generated_sr, hr.detach())
#pixel_loss = 1.0 * mse_criterion(generated_sr, hr.detach())
generator_accuracy += compute_accuracy(dg_out, true_labels)
# update parameters
g_loss += content_loss + g_bce_loss + pixel_loss
g_optim.zero_grad()
g_loss.backward()
g_optim.step()
if i % 20 == 0 and i != 0:
writer.add_scalar("Loss/discriminator", d_loss / 20, i)
writer.add_scalar("Loss/generator", g_loss / 20, i)
writer.add_scalar("Accuracy/discriminator", discriminator_accuracy / 20, i)
writer.add_scalar("Accuracy/generator", generator_accuracy / 20, i)
print(f"EPOCH={epoch} | BATCH={i} | GLOSS={g_loss / 20} | DLOSS={d_loss / 20} | ACC_DISC={discriminator_accuracy / 20} | ACC_GENER={generator_accuracy/20}")
d_loss = 0
g_loss = 0
discriminator_accuracy = 0
generator_accuracy = 0
if i % 100 == 0:
with torch.no_grad():
downsampled = lr[0]
generated = generated_sr[0]
gt = hr[0]
show(downsampled)
show(generated)
show(gt)
thank you very much !