Hi,
I am trying to build a super-resolution model based on the srez project (GitHub: david-gpu/srez, "Image super-resolution through deep learning").
It is almost like a DCGAN, but instead of feeding the generator a noise vector of shape (100, 1, 1), you provide the 16x16 downscaled real image (from my understanding). To compute the generator loss, I need to:
- trick the discriminator (as in DCGAN)
- compute the L1 loss between the downscaled real image and the downscaled generated image (from my understanding)
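Written out, the generator objective described by the two items above would look roughly like this. The weight \lambda and the operator name \mathrm{down} are assumptions added for illustration; the training loop in this post uses a smooth L1 criterion and simply sums the two terms, which corresponds to \lambda = 1.

```latex
\mathcal{L}_G
  = \mathrm{BCE}\big(D(G(x_{\mathrm{lr}})),\, 1\big)
  + \lambda \, \mathrm{L1}\big(\mathrm{down}(G(x_{\mathrm{lr}})),\, x_{\mathrm{lr}}\big)
```

Here x_{\mathrm{lr}} is the 16x16 downscaled real image, G is the generator, and D is the discriminator.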
I get the following error:
Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
My intuition is that it has to do with the use of .detach(). I had a similar issue when implementing DCGAN, but I don’t know how to fix it here.
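For reference, this error can be reproduced in a few lines, independent of any GAN. The sketch below is a toy illustration, not the actual model: `w`, `shared`, `loss_a` and `loss_b` are hypothetical names, with `shared` standing in for the generator output that two different losses depend on.

```python
import torch

# Toy stand-in for the generator: one parameter, one op that saves tensors.
w = torch.tensor([2.0], requires_grad=True)
shared = w * w                 # multiplication saves its inputs for backward

loss_a = (shared * 3.0).sum()  # plays the role of the adversarial loss
loss_b = (shared * 5.0).sum()  # plays the role of the L1 loss

loss_a.backward()              # frees the tensors saved for `shared`
try:
    loss_b.backward()          # needs those saved tensors again
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

# Summing the losses and calling backward once traverses the graph one time.
w2 = torch.tensor([2.0], requires_grad=True)
shared2 = w2 * w2
total = (shared2 * 3.0).sum() + (shared2 * 5.0).sum()
total.backward()
```

In this toy case the second `backward()` fails because both losses share the part of the graph that produced `shared`. Calling `loss_a.backward(retain_graph=True)` would also avoid the error, at the cost of keeping the saved tensors in memory longer.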
Here is my training loop:
for epoch in range(epochs):
    for i, data in enumerate(data_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train the discriminator on real images
        discriminator.zero_grad()
        # Size the labels from the batch so a smaller final batch still works
        true_labels = torch.ones(real_images.size(0), 1, device=device)
        real_preds = discriminator(real_images)
        real_loss = bce_criterion(real_preds, true_labels)  # (input, target) order
        real_loss.backward()

        # Downscale the real images to (ds, ds); this is the generator input
        print("downscaling the image")
        downscaled_image = torch.nn.Upsample((ds, ds))(real_images)
        print(downscaled_image.shape)

        # Train the discriminator on generated images
        fake_labels = torch.zeros_like(true_labels)
        generated_images = generator(downscaled_image)
        # detach() keeps this backward pass out of the generator
        generated_preds = discriminator(generated_images.detach())
        fake_loss = bce_criterion(generated_preds, fake_labels)
        fake_loss.backward()
        discriminator_loss = fake_loss + real_loss
        d_optim.step()

        # Train the generator: try to make the discriminator predict "real"
        generator.zero_grad()
        dg_out = discriminator(generated_images)
        generator_bce_loss = bce_criterion(dg_out, true_labels)  # (input, target) order
        generator_bce_loss.backward()
        # Downscale the generated images and compare them with the generator input
        downscaled_generated = torch.nn.Upsample((ds, ds))(generated_images)
        l1_loss = l1smooth_criterion(downscaled_generated, downscaled_image)
        l1_loss.backward()
        generator_loss = l1_loss + generator_bce_loss
        # Update the generator parameters
        g_optim.step()

        if i % 20 == 0:
            print(f"Epoch : {epoch + 1} | Batch : {i+1} | D loss : {discriminator_loss} | G loss : {generator_loss}")
            # Log the current batch losses
            writer.add_scalar('discriminator loss', discriminator_loss.item(), epoch * len(data_loader) + i)
            writer.add_scalar('generator loss', generator_loss.item(), epoch * len(data_loader) + i)
            # `grid` is assumed to be defined elsewhere (not shown in this post)
            plot_image(grid[:10])
            plt.show()
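One likely cause, offered as a sketch and not a confirmed diagnosis: in the generator step, `generator_bce_loss.backward()` and `l1_loss.backward()` both backpropagate through the same generator graph (via `generated_images`), and the first call frees it. A possible fix is to sum the two losses and call `backward()` once. The tiny networks, image sizes, and learning rate below are hypothetical stand-ins so the snippet runs on its own; `bce_criterion` is assumed to be `nn.BCELoss` and `l1smooth_criterion` to be `nn.SmoothL1Loss`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
ds, full = 16, 64  # low-res and full-res sizes (assumed values)

# Tiny hypothetical stand-ins for the real generator and discriminator.
generator = nn.Sequential(
    nn.Upsample(size=(full, full)),
    nn.Conv2d(3, 3, kernel_size=3, padding=1),
    nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Flatten(), nn.Linear(3 * full * full, 1), nn.Sigmoid()
)
bce_criterion = nn.BCELoss()
l1smooth_criterion = nn.SmoothL1Loss()
g_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)

# Random data in place of a real batch.
real_images = torch.rand(4, 3, full, full) * 2 - 1
downscaled_image = F.interpolate(real_images, size=(ds, ds))
generated_images = generator(downscaled_image)

# Generator step: build both loss terms, then backpropagate once.
generator.zero_grad()
true_labels = torch.ones(real_images.size(0), 1)
generator_bce_loss = bce_criterion(discriminator(generated_images), true_labels)
downscaled_generated = F.interpolate(generated_images, size=(ds, ds))
l1_loss = l1smooth_criterion(downscaled_generated, downscaled_image)
generator_loss = generator_bce_loss + l1_loss
generator_loss.backward()  # single traversal of the generator's graph
g_optim.step()
```

The discriminator also accumulates gradients during this step, but the `discriminator.zero_grad()` call at the start of the next iteration clears them before they are used.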
Thank you