Hi,
I have a model for super resolution (ESRGAN) and I have the following error :
> Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
Here is my training loop :
for i, sample in enumerate(train_loader):
hr = sample["hr"].to(config.device)
lr = sample["lr"].to(config.device)
# generated super resolution
sr = generator(lr)
###### train discriminator ######
for param in discriminator.parameters():
param.requires_grad = True
discriminator.zero_grad()
true_label = torch.full(size=(sr.shape[0], 1), fill_value=1.0, device=config.device)
fake_label = torch.full(size=(sr.shape[0], 1), fill_value=1.0, device=config.device)
predicted_true = discriminator(hr)
predicted_fake = discriminator(sr.detach())
d_loss_true = adversarial_criterion(torch.sigmoid(predicted_true - predicted_fake.mean(dim=0)), true_label)
d_loss_true.backward()
d_loss_fake = adversarial_criterion(torch.sigmoid(predicted_fake - predicted_true.mean(dim=0)), fake_label)
d_loss_fake.backward()
# optimization
d_loss = d_loss_fake + d_loss_true
d_optim.step()
###### train generator ######
for param in discriminator.parameters():
param.requires_grad = False
generator.zero_grad()
d_out_generated = discriminator(sr)
# mse/vgg loss
vgg_loss = vgg_criterion(sr, hr)
vgg_loss.backward()
# tricking the discriminator
adversarial_loss = config.adversarial_coefficient * adversarial_criterion(d_out_generated, true_label)
adversarial_loss.backward()
# l1 criterion
l1_loss = config.l1_coefficient * l1_criterion(sr, hr)
l1_loss.backward()
# relativistic loss
relativistic_loss = config.relativistic_coefficient * adversarial_criterion(torch.sigmoid(d_out_generated - predicted_true.mean(dim=0)), true_label)
relativistic_loss.backward()
# complete loss
g_loss = vgg_loss + adversarial_loss + l1_loss + relativistic_loss
# optimization
g_optim.step()
# writing with tensorboard
writer.add_scalar(f"{config.train_mode}/D_LOSS", d_loss, epoch*len(train_loader) + i + 1)
writer.add_scalar(f"{config.train_mode}/G_LOSS", g_loss, epoch*len(train_loader) + i + 1)
writer.add_scalar(f"{config.train_mode}/l1_loss", l1_loss, epoch*len(train_loader) + i + 1)
writer.add_scalar(f"{config.train_mode}/vgg_loss", vgg_loss, epoch*len(train_loader) + i + 1)
writer.add_scalar(f"{config.train_mode}/adversarial_loss", adversarial_loss, epoch*len(train_loader) + i + 1)
writer.add_scalar(f"{config.train_mode}/relativistic_loss", relativistic_loss, epoch*len(train_loader) + i + 1)
if i % 200 == 0 and i != 0:
print(f"EPOCH={epoch} [{i}/{len(train_loader)}]D_LOSS in {config.train_mode} mode : {d_loss} ")
print(f"EPOCH={epoch} [{i}/{len(train_loader)}]G_LOSS in {config.train_mode} : {g_loss} ")
Do you see where the error is ?
Thank you !