Backward runs twice without retain_graph=True where I don't expect it


I’m trying to avoid retain_graph=True, but I can’t figure out where I’m doing an extra backward pass that I shouldn’t. Could someone please give me some guidance? Below is my code.

I’m porting a TensorFlow meta-learning algorithm to PyTorch. It’s a variant of MAML, so there is an ‘inner loop’ where we do two backward passes. The inner loop is handled by autograd.grad, and the outer loop wrapping it calls loss.backward(). From other people’s MAML implementations, and from many posts here about multiple backward calls, I understand that the outer loop should not need retain_graph=True to work. But a RuntimeError says it does, and I’m struggling to find where I went wrong. Here’s my model (what makes it different from MAML is an ‘inference network’, which I implemented as an extra nn.Module: the self.encoder field in __init__()):
Code I’m porting: haebeom-lee/l2b on GitHub, the TensorFlow implementation of "Learning to Balance: Bayesian Meta-learning for Imbalanced and Out-of-distribution Tasks" (ICLR 2020 oral)

import torch
import torch.nn.functional as F
import numpy as np
from layers import *
from encoder import InferenceNetwork
from torch import autograd
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class LearningToBalance():
  def __init__(self, data_name : str, number_of_inner_gradient_steps : int,
               ways : int, shots : int, inner_learning_rate : float,
               outer_learning_rate : float, batch_size : int):
    if data_name == 'cifar':
      self.xdim = 32
      self.number_of_channels = 32
    elif data_name == 'mimgnet':
      self.xdim = 84
      self.number_of_channels = 32
    else:
      raise ValueError("Invalid dataset name")
    # training parameters
    self.number_of_inner_gradient_steps = number_of_inner_gradient_steps
    self.number_of_classes = ways
    # self.metabatch = metabatch
    self.inner_lr = inner_learning_rate
    self.outer_lr = outer_learning_rate
    self.batch_size = batch_size # number of TASKs in a batch!
    # balancing variables from InferenceNetwork
    self.encoder = InferenceNetwork(ways, shots, data_name, True, True, True).to(DEVICE)
    self.use_inner_step_size_vector = True # learnable, from Meta-SGD
    self.use_zeta = True
    self.use_gamma = True
    self.use_omega = True
    self.use_alpha = True # learn alpha, instead of using set value
    """model definition: 4 convolutional layers, followed by 1 linear layer"""
    self.number_of_convolutional_layer = 4
    self.input_channel = 3
    self.kernel_size = 3
    self.number_of_hidden_channels = NUMBER_OF_HIDDEN_CHANNELS
    self.parameter = self.get_parameter_maml('theta') # this is 'theta'
    self.alpha = self.get_parameter_maml('alpha')
    self.optimizer = torch.optim.Adam(params=list(self.parameter.values()) + list(self.alpha.values()), lr=self.outer_lr,)
    self.validation_interval = 50
  def get_parameter_maml(self, name : str) -> dict[str, torch.Tensor]:
    """initializes self.parameters"""
    # initializers
    if name == 'theta':
      init_convolution  = lambda x : torch.nn.init.xavier_uniform_(x)
      init_bias         = lambda x : torch.nn.init.zeros_(x)
      init_dense        = lambda x : torch.nn.init.xavier_uniform_(x)
    else: # name == 'alpha'
      init_convolution  = lambda x : torch.nn.init.constant_(x, 0.01)
      init_bias         = lambda x : torch.nn.init.constant_(x, 0.01)
      init_dense        = lambda x : torch.nn.init.constant_(x, 0.01)
    parameters = {}
    input_channel = self.input_channel
    # convolutional layers
    print(f"model, get_parameter, generating params {name}")
    for l in range(1, self.number_of_convolutional_layer + 1):
      parameters[f'convolution_{l}_weight'] = init_convolution(torch.empty(
        size=[self.number_of_hidden_channels, input_channel, self.kernel_size, self.kernel_size],
        device=DEVICE, requires_grad=True))
      parameters[f'convolution_{l}_bias'] = init_bias(torch.empty(
        # do NOT assign int to 'size' field! (type problem)
        # for weight dimension [64, 3, 3, 3], we expect bias of [64]<-one dimension only!
        size=[self.number_of_hidden_channels],
        device=DEVICE, requires_grad=True))
      input_channel = self.number_of_hidden_channels
    # linear layer
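    parameters[f'dense_weight'] = init_dense(torch.empty(
      size=[self.number_of_classes, self.number_of_hidden_channels],
      device=DEVICE, requires_grad=True))
    parameters[f'dense_bias'] = init_bias(torch.empty(
      size=[self.number_of_classes], device=DEVICE, requires_grad=True))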
    return parameters
  def forward_theta(self, x : torch.Tensor, theta : dict[str, torch.Tensor], name : str):
    x = torch.reshape(x, [-1, self.input_channel, self.xdim, self.xdim])
    for l in range(1, self.number_of_convolutional_layer + 1):
      if (x.isnan().any()):
      x = F.conv2d(x, theta[f'convolution_{l}_weight'], theta[f'convolution_{l}_bias'], stride=1, padding='same')
      if (x.isnan().any()):
      x = F.batch_norm(x, None, None, training=True)
      if (x.isnan().any()):
      x = F.relu(x)
    x = torch.mean(x, dim=[2, 3])
    x = DenseBlock_F(x, theta[f'dense_weight'], theta[f'dense_bias']) # just executes the calculation
    return x
  def zeta_update_initialization(self, 
    zeta : dict[str, torch.Tensor], theta : dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """update theta with zeta"""
    theta_update_by_zeta = {}
    if self.use_zeta:
      for layer_key in zeta.keys():
        if '_w' in layer_key:
          # θ_0 = θ * (1 + ζ) for weights
          theta_update_by_zeta[layer_key] = theta[layer_key] * (1. + zeta[layer_key])
        elif '_b' in layer_key:
          # θ_0 = θ + ζ for biases
          theta_update_by_zeta[layer_key] = theta[layer_key] + zeta[layer_key]
        else:
          raise KeyError("LearningToBalance::zeta_update_initialization(): \n\tcheck your dictionary keys!")
    return theta_update_by_zeta
  def omega_modify_loss_of_class(self, 
    omega: torch.Tensor, cross_entropy_per_class : torch.Tensor) -> torch.Tensor:
    if self.use_omega:
      inner_loss = torch.sum(cross_entropy_per_class * F.softmax(omega, -1))
    else:
      inner_loss = torch.sum(cross_entropy_per_class)
    return inner_loss
  def gamma_modeify_inner_step_learning_rate(self, 
    theta_tensor : torch.Tensor, gamma_tensor : torch.Tensor, delta : torch.Tensor) -> torch.Tensor:
    """change learning rate with gamma"""
    if self.use_gamma:
      return theta_tensor - delta * torch.exp(gamma_tensor)
    else:
      return theta_tensor - delta
  def _inner_loop(self, x_train : torch.Tensor, y_train : torch.Tensor, train : bool, 
    omega : torch.Tensor, gamma : dict[str, torch.Tensor]) -> tuple[dict[str, torch.Tensor], list[torch.Tensor]]:
    """return: updated parameters theta, list of inner accuracy (length is number_of_inner_gradient_steps + 1)"""
    accuracy_list = []
    y_train =
    param_cloned = {
            k: torch.clone(v)
            for k, v in self.parameter.items()
    }
    for _ in range(self.number_of_inner_gradient_steps):
      if (x_train.isnan().any()):
        print(f"forward batch_norm: NaN in input")
      inner_logits = self.forward_theta(x_train, param_cloned, name="inner_loop") # predictions
      cross_entropy_per_class = F.cross_entropy(inner_logits, y_train)
      inner_accuracy = Accuracy(inner_logits, y_train)
      accuracy_list.append(inner_accuracy)
      """omega: modulate class-specific loss"""
      inner_loss = self.omega_modify_loss_of_class(omega, cross_entropy_per_class)
      # create_graph=True builds a graph of the gradient itself, so second-order
      # derivatives can be computed when we call .backward() on the OUTER loss.
      # Training: build the graph (needed for backprop); validation: don't.
      grads = autograd.grad(
        outputs=inner_loss, inputs=param_cloned.values(), create_graph=train)
      gradient_dictionary : dict[str, torch.Tensor] = dict(zip(param_cloned.keys(), grads))
      """inner gradient step"""
      for layer_key in param_cloned.keys():
        if self.use_alpha: # manually checked that alpha is bound
          delta = self.alpha[layer_key] * gradient_dictionary[layer_key]
        else:
          delta = self.inner_lr * gradient_dictionary[layer_key]
        """use gamma: modulate task-specific learning rate"""
        param_cloned[layer_key] = self.gamma_modeify_inner_step_learning_rate(
          theta_tensor=param_cloned[layer_key], gamma_tensor=gamma[layer_key], delta=delta)
    # no gradient update for this last logit and accuracy
    last_inner_logits = self.forward_theta(x_train, param_cloned, name="inner_loop last")
    last_inner_accuracy = Accuracy(last_inner_logits, y_train)
    accuracy_list.append(last_inner_accuracy)
    return param_cloned, accuracy_list
  def _outer_step_single_task(self, input_task : tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], 
    train : bool) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]:
    """return: cross entropy, outer accuracy, KL, prediction, inner accuracy
    \nouter objective w.r.t. one task"""
    x_train, y_train, x_test, y_test = input_task
    omega, gamma, zeta, KL = self.encoder(x_train, y_train, do_sample=train)
    Cardinality_train = x_train.shape[0]
    Cardinality_test = x_test.shape[0]
    # scale KL term w.r.t train and test set sizes
    KL /= (Cardinality_train + Cardinality_test)
    """zeta: modulate MAML initialization"""
    theta_update_by_zeta = self.zeta_update_initialization(zeta, self.parameter)
    """inner gradient steps; omega & gamma for task_specific modulation"""
    theta_update_by_inner_loop, inner_accuracy = self._inner_loop(x_train, y_train, train, omega, gamma)
    # self.parameter.update(theta_update_by_inner_loop)
    """outer-loss & test_accuracy"""
    logits_test = self.forward_theta(x_test, theta_update_by_inner_loop, name="outer loop")
    cross_entropy = CrossEntropy(logits_test, y_test)
    outer_accuracy = Accuracy(logits_test, y_test)
    prediction = F.softmax(logits_test, -1)
    return cross_entropy, outer_accuracy, KL, prediction, inner_accuracy
  def _outer_step_(self, input_task_batch : list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], train : bool):
    """return: loss, outer accuracy, KL, inner accuracy
    \nmade of _outer_step_single_task"""
    # print(f"model _outer_step")
    cross_entropy_list = []
    outer_accuracy_list : list[torch.Tensor]= []
    KL_list = []
    prediction_list = []
    inner_accuracy_list : list[list[torch.Tensor]]= []
    for t in range(self.batch_size):
      input_task = (input_task_batch[0][t], input_task_batch[1][t], 
                    input_task_batch[2][t], input_task_batch[3][t])
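      cross_entropy, outer_accuracy, KL, prediction, inner_accuracy = self._outer_step_single_task(input_task, train)
      cross_entropy_list.append(cross_entropy)
      outer_accuracy_list.append(outer_accuracy)
      KL_list.append(KL)
      prediction_list.append(prediction)
      inner_accuracy_list.append(inner_accuracy)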
    loss = torch.mean(torch.stack(cross_entropy_list))
    KL = torch.mean(torch.stack(KL_list))
    return loss, torch.mean(torch.stack(outer_accuracy_list)), KL, torch.mean(torch.Tensor(inner_accuracy_list), dim=0)
  def train(self, train_dataloader, valid_dataloader, ):
    """return: train loss, valid loss, train accuracy dictionary, valid accuracy dictionary"""
    with autograd.detect_anomaly():
      print(f"train_dataloader length? {len(train_dataloader)}")
      print(f"valid_dataloader length? {len(valid_dataloader)}")
      # print(f"length of dataloader {len(train_dataloader)}")
      train_loss = []
      train_outer_accuracy = [] # outer means query; inner means support
      train_inner_accuracy_fresh = [] # pre-adapting
      train_inner_accuracy_adapt = [] # post-adapting
      """TODO: show valid loss s.t. it reflects when number of validation != number of training, or throw Error"""
      valid_loss = []
      valid_outer_accuracy = [] # outer means query; inner means support
      valid_inner_accuracy_fresh = [] # pre-adapting
      valid_inner_accuracy_adapt = [] # post-adapting
      # print(f"yo")
      for step, task_batch in enumerate(train_dataloader):
        """task_batch is a BATCH (list) of tasks, batch = batch_size,
        shape: [task batch, number of examples, width, height, number of channels]
        e.g. [5, 70 = way * (shot + query), 32, 32, 3 for RGB]"""
        outer_loss, outer_accuracy, KL, inner_accuracy = self._outer_step_(task_batch, train=True)
        if step % self.validation_interval == 0: # it's validation time!
          for v_step, task_batch in enumerate(valid_dataloader):
            # print(f"valid step {v_step}")
            outer_loss, outer_accuracy, KL, inner_accuracy = self._outer_step_(task_batch, train=False)
      train_accuracy = {
        'outer':            train_outer_accuracy,
        'inner pre-adapt':  train_inner_accuracy_fresh,
        'inner post_adapt': train_inner_accuracy_adapt
      }
      valid_accuracy = {
        'outer':            valid_outer_accuracy,
        'inner pre-adapt':  valid_inner_accuracy_fresh,
        'inner post_adapt': valid_inner_accuracy_adapt
      }
      return train_loss, valid_loss, train_accuracy, valid_accuracy
  def test(self, test_dataloader):
    """return: test loss, test accuracy dictionary"""
    print(f"test_dataloader length? {len(test_dataloader)}")
    test_loss = []
    test_outer_accuracy = [] # outer means query; inner means support
    test_inner_accuracy_fresh = [] # pre-adapting
    test_inner_accuracy_adapt = [] # post-adapting
    for step, task_batch in enumerate(test_dataloader):
      # print(f"test step {step}")
      outer_loss, outer_accuracy, KL, inner_accuracy = self._outer_step_(task_batch, train=False)
    test_accuracy = {
      'outer':            test_outer_accuracy,
      'inner pre-adapt':  test_inner_accuracy_fresh,
      'inner post_adapt': test_inner_accuracy_adapt
    }
    return test_loss, test_accuracy

Parameters are stored in a dictionary (4 convolution layers followed by one linear layer) and are updated manually. The ‘encoder’ network is called once per task to compute some variables (omega, gamma, zeta, plus a KL term) that join the calculation. My suspicion is that something is wrong with the inference network: it basically generates summary statistics from the data, and since I was confused by the TensorFlow code, I left it as a learnable nn.Module. What I’m missing is something like autograd.detect_double_backward() (a made-up function) that tells me exactly where I’m calling backward when I shouldn’t, or some things to keep in mind when reading my code for bugs.
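For comparison, here is a minimal sketch of the usual MAML-style pattern, assuming a toy linear model whose parameters live in a dictionary (the names `params`, `forward`, `maml_step`, and the random data are hypothetical, not from the l2b code). `autograd.grad(..., create_graph=True)` also keeps the inner graph alive, because `retain_graph` defaults to the value of `create_graph`, and everything the outer loss depends on is rebuilt inside each iteration, so one `loss.backward()` per step needs no `retain_graph=True`:

```python
import torch
import torch.nn.functional as F
from torch import autograd

torch.manual_seed(0)

# Hypothetical toy model: parameters kept in a dict and applied functionally,
# like the dictionary-of-parameters style in the question.
params = {
    "w": torch.randn(3, 2, requires_grad=True),
    "b": torch.zeros(2, requires_grad=True),
}
optimizer = torch.optim.Adam(list(params.values()), lr=1e-2)
inner_lr = 0.1
inner_steps = 2


def forward(x, p):
    return x @ p["w"] + p["b"]


def maml_step(x_support, y_support, x_query, y_query):
    # Start the inner loop from the meta-parameters; the updates below are
    # out-of-place, so the meta-parameters themselves are never modified.
    fast = dict(params)
    for _ in range(inner_steps):
        inner_loss = F.cross_entropy(forward(x_support, fast), y_support)
        # create_graph=True makes the gradients differentiable (second order);
        # retain_graph defaults to create_graph, so the inner graph survives
        # until the outer backward call.
        grads = autograd.grad(inner_loss, list(fast.values()), create_graph=True)
        fast = {k: v - inner_lr * g for (k, v), g in zip(fast.items(), grads)}
    return F.cross_entropy(forward(x_query, fast), y_query)


losses = []
for step in range(3):
    # Fresh support/query data each step: the whole graph is built inside
    # this loop body and consumed by exactly one backward call.
    xs, ys = torch.randn(10, 3), torch.randint(0, 2, (10,))
    xq, yq = torch.randn(10, 3), torch.randint(0, 2, (10,))
    optimizer.zero_grad()
    loss = maml_step(xs, ys, xq, yq)
    loss.backward()  # no retain_graph=True needed
    optimizer.step()
    losses.append(loss.item())  # store a float, not a tensor holding the graph
```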

Thanks so much!

If you use torch.autograd.detect_anomaly (documented on the torch.autograd page of the PyTorch 2.0 documentation), for example:

import torch

with torch.autograd.detect_anomaly():
    a = torch.tensor(1., requires_grad=True)
    b = a.sin()

    c = b.sin()
    d = b.sin()

    c.backward()
    d.backward()  # raises the RuntimeError shown below
it tells you in which part of the backward pass the error occurred, and which forward operation in your code that part corresponds to. Running the snippet above prints:

/local/.tests/ UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with torch.autograd.detect_anomaly():
/local/pytorch3/torch/autograd/ UserWarning: Error detected in SinBackward0. Traceback of forward call that caused the error:
  File "/local/.tests/", line 1346, in <module>
    b = a.sin()
 (Triggered internally at /local/pytorch3/torch/csrc/autograd/python_anomaly_mode.cpp:118.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/local/.tests/", line 1352, in <module>
  File "/local/pytorch3/torch/", line 492, in backward
  File "/local/pytorch3/torch/autograd/", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
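For the toy example, here is a sketch of the two usual ways out, using plain PyTorch and nothing beyond the snippet above: keep the graph alive for the second pass with `retain_graph=True`, or combine the outputs so the shared node is backpropagated through only once. Both give the same gradient in `a.grad`:

```python
import torch

# Option 1: retain the graph on the first backward so the shared node
# (SinBackward0 of b = a.sin()) can be traversed again by the second one.
a = torch.tensor(1., requires_grad=True)
b = a.sin()
c = b.sin()
d = b.sin()
c.backward(retain_graph=True)
d.backward()  # works now; a.grad accumulates both contributions
grad_two_passes = a.grad.clone()

# Option 2: sum the outputs and call backward once, so every node
# is visited a single time and nothing has to be retained.
a2 = torch.tensor(1., requires_grad=True)
b2 = a2.sin()
(b2.sin() + b2.sin()).backward()
grad_one_pass = a2.grad.clone()

# Analytic check: d/da [2 * sin(sin(a))] = 2 * cos(sin(a)) * cos(a)
x = torch.tensor(1.)
expected = 2 * torch.cos(torch.sin(x)) * torch.cos(x)
```

The second option matches what the question is aiming for: one backward call per graph, with no retained graph.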

Thanks very much for your help, soulitzer! Could you please help me map the mistake in your example to the printout?

Here’s what I’ve made out so far:
In the example, the error is that c.backward() doesn’t have retain_graph=True.
It looks like detect_anomaly() picked out two errors:
1. b = a.sin() # I’m not sure what’s wrong with this operation
2. torch.autograd.backward() # this isn’t in the code. Perhaps it comes from d.backward() above?

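It’s pointing out that running a.sin() in the forward pass is what produced the part of the autograd graph that is backpropagated through twice by our two backward calls. There is only one error: the second traceback is just the call stack of the failing d.backward() call (Tensor.backward calling torch.autograd.backward).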

It’s backpropagated through twice because its output, b, was used to compute both c and d.
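In a training loop, the same situation typically shows up when a tensor that is still attached to the graph is computed once and then reused across iterations. Here is a hypothetical sketch (the names `w`, `shared`, and the toy loss are made up for illustration; this is not a diagnosis of the l2b port above):

```python
import torch

w = torch.randn(3, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

# Buggy version: `shared` is computed ONCE, outside the loop, so every
# iteration's loss hangs off the same ExpBackward0 node.
shared = w.exp()
error_message = None
try:
    for step in range(2):
        opt.zero_grad()
        loss = (shared ** 2).sum()
        loss.backward()  # the first call frees the tensors saved by w.exp()
        opt.step()
except RuntimeError as e:
    error_message = str(e)  # raised on the second iteration

# Fixed version: recompute the graph-attached tensor inside the loop,
# so each backward call gets a fresh graph of its own.
fixed_losses = []
for step in range(2):
    opt.zero_grad()
    shared = w.exp()
    loss = (shared ** 2).sum()
    loss.backward()
    opt.step()
    fixed_losses.append(loss.item())
```

When reading code for this kind of bug, it can help to look for tensors that are created outside the per-step work (for example in __init__ or before the loop) and then feed into every step's loss.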


Thank you so much! I think I can figure out where I went wrong!