NOT SOLVED: retain_graph=True problem

I am working with a generator and a discriminator in my network (a GAN).
I wanted to add a new type of loss that is computed from the features of every layer of the discriminator (these features are computed in the discriminator's forward()), but that contributes to the generator loss.
Let’s call this loss gen_feat_loss.

So this is what I do:

  1. call the discriminator's forward(), which returns a list with the output of every discriminator layer.
  2. compute the discriminator loss from the last element of that list.
  3. compute gen_feat_loss.
  4. call the generator's forward() and compute all the other losses.
  5. call discriminator.zero_grad()
  6. discriminator_loss.backward()
  7. opt_discriminator.step()
  8. call generator.zero_grad()
  9. generator_loss = loss1 + loss2 + ... + gen_feat_loss
  10. generator_loss.backward()
  11. opt_generator.step()

However, when I run the code, I see the following error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
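The steps above can be reproduced with toy models. This is a minimal sketch, assuming a hypothetical `TinyDisc` whose forward() returns every layer's output, an `nn.Linear` stand-in for the generator, and BCE-with-logits and L1 as the losses (none of these come from the original code). Because gen_feat_loss is built from the same discriminator forward pass as the discriminator loss, the discriminator backward frees the saved tensors it depends on, and the generator backward then raises this RuntimeError.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyDisc(nn.Module):
    """Hypothetical discriminator whose forward() returns every layer's output."""

    def __init__(self, in_dim=4, hidden=8):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(in_dim, hidden), nn.LeakyReLU(0.2)),
            nn.Sequential(nn.Linear(hidden, hidden), nn.LeakyReLU(0.2)),
            nn.Linear(hidden, 1),  # last element: real/fake logits
        ])

    def forward(self, x):
        feats = []
        for layer in self.layers:
            x = layer(x)
            feats.append(x)
        return feats


def reproduce_error():
    torch.manual_seed(0)
    gen = nn.Linear(3, 4)  # hypothetical stand-in for the generator
    disc = TinyDisc()
    opt_d = torch.optim.SGD(disc.parameters(), lr=0.01)
    opt_g = torch.optim.SGD(gen.parameters(), lr=0.01)
    z, real = torch.randn(5, 3), torch.randn(5, 4)
    fake = gen(z)

    # Steps 1-3: ONE discriminator forward pass feeds both losses.
    fake_feats = disc(fake.detach())
    real_feats = disc(real)
    d_loss = (F.binary_cross_entropy_with_logits(fake_feats[-1], torch.zeros_like(fake_feats[-1]))
              + F.binary_cross_entropy_with_logits(real_feats[-1], torch.ones_like(real_feats[-1])))
    gen_feat_loss = sum(F.l1_loss(f, r.detach()) for f, r in zip(fake_feats[:-1], real_feats[:-1]))

    # Steps 5-7: this backward frees the saved tensors of the discriminator graph.
    disc.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator update: gen_feat_loss still points into that freed graph.
    gen.zero_grad()
    generator_loss = gen_feat_loss  # loss1, loss2, ... omitted
    try:
        generator_loss.backward()
    except RuntimeError as err:
        return str(err)
    opt_g.step()
    return None


print(reproduce_error())
```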

I tried changing the order of the computations, i.e. calling the generator backward first and the discriminator backward afterwards.
I then tried retain_graph=True, as the error suggests, but I run out of RAM.
Do you have any idea what I am doing wrong?
It must be a problem with the generator backward, which tries to compute the gradients of gen_feat_loss through a part of the graph that the discriminator backward has already gone through in the same step.

When you fed the fake image into the discriminator, did you detach it using

image.detach()

?
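For context, a minimal sketch of what `.detach()` does here, assuming toy `nn.Linear` modules as a hypothetical generator and discriminator: the detached tensor has the same values but no `grad_fn`, so a loss computed on `disc(fake.detach())` produces gradients for the discriminator only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(3, 4)   # hypothetical generator
disc = nn.Linear(4, 1)  # hypothetical discriminator

fake = gen(torch.randn(2, 3))
fake_d = fake.detach()  # same values, cut out of the generator's graph

print(fake.grad_fn is not None, fake_d.grad_fn is None, fake_d.requires_grad)
# True True False

loss = disc(fake_d).mean()
loss.backward()
print(disc.weight.grad is not None)  # True: the discriminator gets a gradient
print(gen.weight.grad is None)       # True: nothing reaches the generator
```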

Yes, I used these lines of code:

fake_list_features, _ = self.run_discriminator(discriminator, fake.detach())
fake_labels = fake_list_features[-1]
fake_loss += self.compute_loss(fake_labels, self.zeros_like(fake_labels))

where run_discriminator calls the discriminator's forward().

OK, can you show the line where you define the discriminator loss?

I apply the discriminator three times, because I am using a multi-scale implementation with inputs at 3 different scales. Therefore I do:

import torch

fake_loss = 0
true_loss = 0
for i in range(num_discriminators):
    if i != 0:
        fake = self.downsample(fake)
        true = self.downsample(true)

    # fake is detached: gradients from this pass stop at the discriminator
    fake_list_features, _ = self.run_discriminator(discriminator, fake.detach())
    fake_labels = fake_list_features[-1]
    fake_loss += self.compute_loss(fake_labels, self.zeros_like(fake_labels))
    true_list_features, _ = self.run_discriminator(discriminator, true)
    true_labels = true_list_features[-1]
    # real images should get target 1, not 0
    true_loss += self.compute_loss(true_labels, torch.ones_like(true_labels))

discriminator_loss = true_loss + fake_loss
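The multi-scale loop above can be sketched as a standalone function. This is a minimal sketch under several assumptions: `FeatDisc` is a hypothetical discriminator whose forward returns the list of layer outputs, `downsample` is approximated with `F.avg_pool2d` (the original `self.downsample` is not shown), and `compute_loss` is taken to be binary cross-entropy on logits. Real images get target 1, fake images target 0, and both accumulators start at zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatDisc(nn.Module):
    """Hypothetical patch discriminator returning every layer's output."""

    def __init__(self, ch=3, hidden=8):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Conv2d(ch, hidden, 3, stride=2, padding=1), nn.LeakyReLU(0.2)),
            nn.Conv2d(hidden, 1, 3, padding=1),  # patch logits
        ])

    def forward(self, x):
        feats = []
        for layer in self.layers:
            x = layer(x)
            feats.append(x)
        return feats


def downsample(x):
    # Assumed stand-in for self.downsample: halve the spatial resolution.
    return F.avg_pool2d(x, kernel_size=3, stride=2, padding=1, count_include_pad=False)


def multiscale_d_loss(disc, fake, real, num_discriminators=3):
    fake_loss = real_loss = 0.0
    for i in range(num_discriminators):
        if i != 0:
            fake, real = downsample(fake), downsample(real)
        fake_logits = disc(fake.detach())[-1]  # no gradient to the generator
        real_logits = disc(real)[-1]
        fake_loss = fake_loss + F.binary_cross_entropy_with_logits(
            fake_logits, torch.zeros_like(fake_logits))
        real_loss = real_loss + F.binary_cross_entropy_with_logits(
            real_logits, torch.ones_like(real_logits))
    return fake_loss + real_loss


torch.manual_seed(0)
disc = FeatDisc()
fake = torch.randn(2, 3, 32, 32, requires_grad=True)
real = torch.randn(2, 3, 32, 32)
loss = multiscale_d_loss(disc, fake, real)
loss.backward()
print(loss.item() > 0, fake.grad is None)  # True True
```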

OK, how about the generator loss? In the first post you only show part of it.

For the generator, I first compute gen_feat_loss from the discriminator features, as I said (this happens inside compute_discriminator, which returns it together with the discriminator loss):
genFeatLoss += self.compute_genFeatLoss(fake_list_features, true_list_features, discriminator.n_layers, num_discriminators)

Then, after computing the perceptual and other losses, I sum them like this in the train method:

for epoch in range(epochs):
    for i, batch in enumerate(self.train_loader):
        ...
        # train discriminator
        self.opt_discriminator.zero_grad()
        discriminator_loss, genFeat_loss = self.compute_discriminator(generator, discriminator, batch)
        discriminator_loss.backward()
        self.opt_discriminator.step()

        # train generator
        self.opt_generator.zero_grad()
        loss_1, loss_2, loss_3, _ = self.compute_generator_loss(generator, discriminator, batch)

        generator_loss = loss_1 + loss_2 + loss_3 + genFeat_loss
        generator_loss.backward()
        self.opt_generator.step()
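One arrangement that avoids both the error and retain_graph=True (it is, for example, how pix2pixHD computes its feature-matching loss) is to compute the feature loss inside the generator step, from a second discriminator pass on the non-detached fake, instead of reusing genFeat_loss from compute_discriminator. A minimal sketch with hypothetical toy models (`TinyDisc`, `train_step`, the layer sizes and the Adam settings are all assumptions), with the other generator losses omitted:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyDisc(nn.Module):
    """Hypothetical discriminator whose forward() returns every layer's output."""

    def __init__(self, in_dim=4, hidden=8):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(in_dim, hidden), nn.LeakyReLU(0.2)),
            nn.Sequential(nn.Linear(hidden, hidden), nn.LeakyReLU(0.2)),
            nn.Linear(hidden, 1),
        ])

    def forward(self, x):
        feats = []
        for layer in self.layers:
            x = layer(x)
            feats.append(x)
        return feats


def train_step(gen, disc, opt_g, opt_d, z, real):
    fake = gen(z)

    # Discriminator step: fake is detached, so only D's graph is involved.
    opt_d.zero_grad()
    fake_logits = disc(fake.detach())[-1]
    real_logits = disc(real)[-1]
    d_loss = (F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
              + F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits)))
    d_loss.backward()
    opt_d.step()

    # Generator step: a NEW discriminator pass on the non-detached fake,
    # so the feature loss has its own graph that reaches the generator.
    opt_g.zero_grad()
    fake_feats = disc(fake)
    with torch.no_grad():
        real_feats = disc(real)  # used as targets only, no graph needed
    gen_feat_loss = sum(F.l1_loss(f, r) for f, r in zip(fake_feats[:-1], real_feats[:-1]))
    g_loss = gen_feat_loss  # loss_1, loss_2, loss_3 would be added here
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()


torch.manual_seed(0)
gen = nn.Linear(3, 4)  # hypothetical generator
disc = TinyDisc()
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3)
for _ in range(3):
    d, g = train_step(gen, disc, opt_g, opt_d, torch.randn(5, 3), torch.randn(5, 4))
print(gen.weight.grad is not None)  # True: the feature loss reaches the generator
```

Each backward() here walks a graph that no other backward() has touched, which is why neither retain_graph=True nor any extra detach() is needed.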

OK, I think that in either compute_generator_loss or compute_discriminator you have not detached the input images or the losses. Can you send those two functions too?

compute_discriminator is the function I posted above the training loop.

compute_generator_loss is not involved: the error only occurs when I add genFeat_loss to the sum
generator_loss = loss_1 + loss_2 + loss_3 + genFeat_loss
If I remove it from the sum, I don’t get any error.
self.compute_genFeatLoss is a simple formula:

loss_G_GAN_Feat = 0
for j in range(len(fake_features)):
    loss_G_GAN_Feat += self.computeL1Loss(fake_features[j], true_features[j].detach())
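The formula above is a feature-matching L1 loss. A minimal sketch, assuming `computeL1Loss` is a mean absolute error (`F.l1_loss`), shows that detaching only the real features keeps the gradient path through the fake features open: whether the generator receives anything then depends on how the fake features themselves were produced.

```python
import torch
import torch.nn.functional as F


def feature_matching_loss(fake_features, true_features):
    """Sum of L1 distances between matching layers; real features are constants."""
    loss = 0.0
    for fake_f, true_f in zip(fake_features, true_features):
        loss = loss + F.l1_loss(fake_f, true_f.detach())
    return loss


fake_feats = [torch.randn(2, 8, requires_grad=True), torch.randn(2, 4, requires_grad=True)]
true_feats = [torch.randn(2, 8, requires_grad=True), torch.randn(2, 4, requires_grad=True)]
loss = feature_matching_loss(fake_feats, true_feats)
loss.backward()
print(all(f.grad is not None for f in fake_feats))  # True
print(all(t.grad is None for t in true_feats))      # True
```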

OK, try to detach fake_features too.

But why?
I want the generator to learn to imitate the true images.
If I detach both, that means I am not computing any gradient from that loss, no?
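This intuition can be checked directly. A minimal sketch with toy modules (`gen`, `disc` and `generator_grad` are hypothetical stand-ins): when the features come from `disc(fake.detach())`, as in `compute_discriminator`, the feature loss already has no path back to the generator, so the generator gets no gradient from it whether or not `fake_features` is also detached. Only a discriminator pass on the non-detached `fake` connects the loss to the generator's parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
gen = nn.Linear(3, 4)   # hypothetical generator
disc = nn.Linear(4, 8)  # hypothetical one-layer "feature extractor"
fake, real = gen(torch.randn(5, 3)), torch.randn(5, 4)


def generator_grad(features_of_fake):
    """Generator weight gradient produced by the L1 feature loss (None if no path)."""
    gen.zero_grad(set_to_none=True)
    loss = F.l1_loss(features_of_fake, disc(real).detach())
    if loss.requires_grad:
        loss.backward()
    return gen.weight.grad


print(generator_grad(disc(fake.detach())))            # None
print(generator_grad(disc(fake.detach()).detach()))   # None
print(generator_grad(disc(fake)) is not None)         # True
```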

Indeed, I tried it and I don’t get the error anymore, but when I train, the generator losses keep increasing.

Yes, that is true, but it should still train. I don’t have any other ideas. Maybe you could use that variable in the discriminator loss, and that could help your generator.

Thanks for trying to help!
But I don’t think it’s a solution, as I need that loss to train the generator.
If I put .detach() on everything, then I am not training with that loss... There must be another error somewhere.

Sorry I couldn’t help. I had the same problem with my GAN a little while ago; this is what I did, and the GAN trains fine. I looked everywhere, but that was the only answer I could find. I hope you find a better solution, because it could help me too. Sorry again. Also, you might want to ask this question again, because people will see that it already has replies and won’t answer it.