RuntimeError: Trying to backward through the graph a second time, mse_loss

Hello.

I am trying to train a network with a custom loss function, but I keep getting the error above even when I call loss.backward(retain_graph=True).

Here is the training code:

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, image in enumerate(train_loader):
        optimizer.zero_grad()
        data = image[0].to(device)
        output = model(data).to(device)
        # VGG feature maps of the input image and of the network output
        data_loss = VGG_model(data)
        output_loss = VGG_model(output)
        # Content loss on the deepest feature map
        content_loss = F.mse_loss(data_loss[-1], output_loss[-1])
        # Style loss: match mean and std of the style-image features at each layer
        style_loss = 0
        for i in range(4):
            style_loss += 0.25 * (F.mse_loss(torch.mean(img_stil_loss[i]), torch.mean(output_loss[i]))
                                  + F.mse_loss(torch.std(img_stil_loss[i]), torch.std(output_loss[i])))
        loss = content_loss + 0.5 * style_loss
        loss.backward(retain_graph=True)
        optimizer.step()
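
For context, VGG_model and img_stil_loss are not defined inside train(); they are created once before training starts. The setup looks roughly like the sketch below (simplified for this post; the layer indices and the style_image name are placeholders, not my exact code):

import torch
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Frozen, pretrained VGG used only as a feature extractor
class VGGFeatures(torch.nn.Module):
    def __init__(self, layer_ids=(3, 8, 17, 26)):   # illustrative layer indices
        super().__init__()
        self.vgg = torchvision.models.vgg19(pretrained=True).features.eval()
        for p in self.vgg.parameters():
            p.requires_grad_(False)                  # VGG weights are not trained
        self.layer_ids = set(layer_ids)

    def forward(self, x):
        feats = []
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            if i in self.layer_ids:
                feats.append(x)                      # collect intermediate activations
        return feats

VGG_model = VGGFeatures().to(device)

# Style-image features, computed once before the training loop
# (style_image is a placeholder name for the preprocessed style image tensor)
img_stil_loss = VGG_model(style_image.to(device))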