Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)

I have the following codes

# Per essere più puliti definire solo la loss, poi il rischio empirico si calcola separatamente
from torch import nn

class AdversarialLoss(nn.Module):
    def __init__(self, discriminator):
      super(AdversarialLoss, self).__init__()
      self.discriminator = discriminator

    def forward(self, output):
      # calcolo le probabilita' per le diverse immagini upscaled
      return -torch.log(self.discriminator(output))
class VggLoss(nn.Module):
  def __init__(self, vgg19_model):
      super(VggLoss, self).__init__()
      self.vgg19_model = vgg19_model
      self.mse_loss = MSELoss()
  
  def get_feature_maps(self, module, input, output):
      self.feature_maps = output

  def forward(self, output, target):
      target_layer = self.vgg19_model.features[7]
        
      self.feature_maps = None
      hook_handle = target_layer.register_forward_hook(self.get_feature_maps)

      self.vgg19_model(output)
      vgg_feature_map_upscaled_img = self.feature_maps

      self.vgg19_model(target)
      vgg_feature_map_high_res_img = self.feature_maps

      hook_handle.remove()

      return self.mse_loss(vgg_feature_map_upscaled_img, vgg_feature_map_high_res_img)
# N.B. : il problema sembrerebbe essere tra discriminator loss e adversarial loss
class DiscriminatorLoss(nn.Module):

  def __init__(self, discriminator):
    super(DiscriminatorLoss, self).__init__()
    self.discriminator = discriminator

  def forward(self, output, target):
    probability_high_res_img = self.discriminator(target)
    probability_upscaled_img = 1 - self.discriminator(output)

    return torch.log(probability_high_res_img) + torch.log(probability_upscaled_img)
from torch.optim import Adam

def train_gan(num_of_iter, batch_size, device_type, data_loader, generator, discriminator, vgg19_model):
  optimizer_gen = Adam(generator.parameters(), lr=0.0001)
  optimizer_disc = Adam(discriminator.parameters(), lr=0.0001)
  iterator = iter(data_loader)

  # sposto i modelli sulla GPU. N.B. : se la memoria viene ecceduta capire quale dei modelli non mettere in memoria
  generator.to(device_type)
  discriminator.to(device_type)
  vgg19_model.to(device_type)

  vgg_loss = VggLoss(vgg19_model)
  adversarial_loss = AdversarialLoss(discriminator)
  discriminator_loss = DiscriminatorLoss(discriminator)

  for iteration in range(num_of_iter):
    # resetto i gradienti
    optimizer_gen.zero_grad()
    optimizer_disc.zero_grad()
    # prendo il prossimo batch
    batch, id = next(iterator)
    low_res_imgs = batch['low_res'].to(device_type)
    high_res_imgs = batch['high_res'].to(device_type)
    # genero le immagini ridimensionate
    upscaled_images = generator(low_res_imgs)
    # valuto la loss avversaria
    loss_adv = adversarial_loss(upscaled_images[0].unsqueeze(0))

    print("before adversarial loss ", iteration)

    for i in range(1, len(upscaled_images)):
      loss_adv += adversarial_loss(upscaled_images[i].unsqueeze(0))  

    print(loss_adv)

    print("after adversarial loss ", iteration)

    # valuto la loss vgg
    loss_vgg = vgg_loss(upscaled_images[0].unsqueeze(0), high_res_imgs[0].unsqueeze(0))

    print("before vgg loss ", iteration)

    for i in range(1, len(high_res_imgs)):
      loss_vgg += vgg_loss(upscaled_images[i].unsqueeze(0), high_res_imgs[i].unsqueeze(0))

    print("after vgg loss ", iteration)

    # calcolo la loss finale
    losses_sum = torch.div((0.006 * loss_vgg + 0.001 * loss_adv), batch_size)

    # calcolo il gradiente
    losses_sum.backward()
    #print(high_res_imgs[0].grad)
    # aggiorno i parametri del generatore
    optimizer_gen.step()

    print("print before loss disc ", iteration)

    loss_disc = discriminator_loss(upscaled_images[0].unsqueeze(0), high_res_imgs[0].unsqueeze(0))

    print("a")

    for i in range(1, len(upscaled_images)):
      loss_disc += discriminator_loss(upscaled_images[i].unsqueeze(0), high_res_imgs[i].unsqueeze(0))

    print(loss_disc)

    print("b")

    # meno in maniera tale che quando si discende il gradiente in realtà lo si massimizza
    loss_disc = torch.div(-1 * loss_disc, batch_size)

    print("c")

    # devo ascendere il gradiente 
    loss_disc.backward()

    print("d")
    # aggiorno i parametri del discriminatore 
    optimizer_disc.step() 

    print("e")

    print("print after loss disc ", iteration)

    if iteration % 3 == 0:
      print("at iteration ", iteration)
      transform = T.ToPILImage()
      upscaled_img = upscaled_images[0][0]
      high_res_img = high_res_imgs[0][0]
      img = transform(upscaled_img)
      img.show()
      img = transform(high_res_img)
      img.show()

When calling the methods with the following code

disc = Discriminator()
num_of_iter = 10000
batch_size = 16
train_gan(num_of_iter, batch_size, dev, data_loader, generator, disc, vgg19_model)

I get the following error : Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Not sure why it’s happening since I’m not calling backward on the same loss two times.

Since two losses loss_disc and losses_sum are both functions of upscaled_images, backward would backprop through that part of the graph twice.

Using retain_graph=True should be what you want to do here.

If you run into an error like “variable saved for backward has been modified in-place” that means you’ll need to move the first optimizer step to AFTER you do the second backward.

1 Like

I think the error you’re encountering is because you’re trying to backpropagate through the computation graph multiple times without setting the retain_graph=True flag. This happens because some of the tensors in the computation graph are shared between the different losses, so when you backpropagate for the first time, the graph is freed, and the second backpropagation attempt results in the error. You should set the retain_graph=True flag when calling backward() for the first time.

1 Like