Error when updating policy networks: Trying to backward through the graph a second time

Hello, I am using PyTorch (the C++ frontend, libtorch) to implement a transformer encoder based on GTrXL and train it with reinforcement learning. A single forward and backward pass completes successfully, but when I call backward again after the second forward pass, I run into this error:

terminate called after throwing an instance of 'c10::Error'
  what():  Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
Exception raised from unpack at ../torch/csrc/autograd/saved_variable.cpp:134 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xb0 (0x739c9cda57f0 in /opt/libtorch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x84 (0x739c9cd4e092 in /opt/libtorch/lib/libc10.so)
frame #2: torch::autograd::SavedVariable::unpack(std::shared_ptr<torch::autograd::Node>) const + 0x1392 (0x739c82691d92 in /opt/libtorch/lib/libtorch_cpu.so)
frame #3: torch::autograd::generated::MmBackward0::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0xc5 (0x739c817ebb25 in /opt/libtorch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x584d2eb (0x739c8264d2eb in /opt/libtorch/lib/libtorch_cpu.so)
frame #5: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x14c2 (0x739c826478c2 in /opt/libtorch/lib/libtorch_cpu.so)
frame #6: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x650 (0x739c82648510 in /opt/libtorch/lib/libtorch_cpu.so)
frame #7: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x194 (0x739c82642134 in /opt/libtorch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0xe1c34 (0x739c1f0e1c34 in /usr/lib/libstdc++.so.6)
frame #9: <unknown function> + 0x9439d (0x739c9a78639d in /usr/lib/libc.so.6)
frame #10: <unknown function> + 0x11949c (0x739c9a80b49c in /usr/lib/libc.so.6)

And here is the code for the backward pass:

void Agent::update_policy(float tau, float gamma) {
    // Sample a batch of experiences from memory.
    auto [states, actions, rewards, next_states, terms] = this->mem.sample(32);
    
    // Determine the target q-values.
    torch::Tensor next_actions = std::get<0>(this->proto(next_states));
    torch::Tensor next_q_vals = this->critic_tgt.forward(next_states, next_actions);
    torch::Tensor target_q_vals = rewards.unsqueeze(-1) + gamma * (1.0 - terms) * next_q_vals;
    
    // Update critic networks.
    this->critic.zero_grad();
    torch::Tensor q_vals = this->critic.forward(states, actions);
    torch::Tensor value_loss = torch::nn::functional::mse_loss(q_vals, target_q_vals);

    // The GTrXL encoder is shared between two linear heads (one for the actor
    // and one for the critic). So, we backward the loss and step the optimizers
    // for both the critic and the GTrXL encoder.
    value_loss.backward();
    this->opt_critic.step();
    this->opt_tfxl.step();

    // Update actor networks.
    this->actor.zero_grad();
    auto [_, _, policy_loss] = this->forward(states);
    policy_loss = -policy_loss.mean();

    // Just like before, since the GTrXL has two heads, we backward the loss,
    // then step the optimizers for both the actor and the GTrXL encoder. This
    // way, the GTrXL encoder has been updated with the gradients from both the
    // actor and critic networks and will eventually be able to provide outputs
    // that are optimal for either head.
    policy_loss.backward();
    this->opt_actor.step();
    this->opt_tfxl.step();

    // Update target networks.
    track(this->actor_tgt, this->actor, tau);
    track(this->critic_tgt, this->critic, tau);
}

Based on my research into this error, it looks like this usually occurs because some tensor is getting stored and used for the backward pass again. So, my questions are:

  1. If the error is here in the update_policy function, what is causing it?
  2. If the error is somewhere else, how can I track down the tensor that is causing problems?

I know that the error message suggests setting retain_graph=true. However, I have looked up this behavior and confirmed that I do not want to do this.
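For completeness, here is what that suggestion would look like from the C++ frontend (a sketch assuming the Tensor::backward overload that accepts a retain_graph argument); keeping the first graph alive across updates is exactly what I want to avoid:

// Not what I want: this keeps the saved tensors of the first graph around
// so that a later backward can walk the same graph again.
value_loss.backward(/*gradient=*/{}, /*retain_graph=*/true);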

Bumping this topic. If anyone has suggestions, that would be much appreciated.

it looks like this usually occurs because some tensor is getting stored and used for the backward pass again

Yes, if you are backwarding multiple times over the same graph and reusing the same saved intermediate tensors, you will hit this error, because those saved values are freed during the first backward unless you pass retain_graph=True.
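A minimal sketch of that failure mode (not your code, just an illustration) hits the same MmBackward0 node that shows up in your stack trace:

// mm() saves its inputs for the backward pass; the first backward() frees
// them, and the second backward() tries to unpack the freed values and
// throws the c10::Error above.
torch::Tensor w = torch::randn({3, 3}, torch::requires_grad());
torch::Tensor x = torch::randn({3, 3});
torch::Tensor y = torch::mm(x, w);

y.sum().backward();   // first backward: fine, saved tensors are then freed
y.sum().backward();   // second backward over the same graph: throws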

However, I have looked up this behavior and confirmed that I do not want to do this.

It sounds like you want to make sure that the two backward passes operate on two disjoint graphs?
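One common way graphs end up chained together is a tensor produced by one forward pass being fed into the next one (for example recurrent state, or Transformer-XL style memory carried across calls) without being detached; the second backward then walks back into the first, already-freed graph. A rough sketch of the detach pattern, with hypothetical names (model, opt, next_batch), would be:

torch::Tensor hidden = torch::zeros({1, 128});

for (int step = 0; step < 100; ++step) {
    auto [output, new_hidden] = model.forward(next_batch(), hidden);
    torch::Tensor loss = output.sum();

    opt.zero_grad();
    loss.backward();               // walks only this step's graph, because the
                                   // hidden state fed in was detached last step
    opt.step();

    hidden = new_hidden.detach();  // cut the graph before the next step reuses it
}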

How can I track down the tensor that is causing problems?

You can use TORCH_LOGS="+autograd" to log the backward pass, or enable anomaly detection (see Automatic differentiation package - torch.autograd — PyTorch 2.5 documentation), which can give you a stack trace of where in the forward pass the operation that failed in backward was created.
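Since you are on the C++ frontend, here is a sketch of turning anomaly mode on there (assuming the torch::autograd::AnomalyMode API exposed in the libtorch headers; the Python equivalent is torch.autograd.set_detect_anomaly(True)):

#include <torch/torch.h>
#include <torch/csrc/autograd/anomaly_mode.h>

int main() {
    // Anomaly mode records extra metadata during the forward pass so that a
    // failure during backward can be traced back to the node that produced it.
    // It adds overhead, so enable it only while debugging.
    torch::autograd::AnomalyMode::set_enabled(true);

    torch::Tensor w = torch::randn({2, 2}, torch::requires_grad());
    torch::Tensor y = torch::mm(torch::randn({2, 2}), w);
    y.sum().backward();
    y.sum().backward();  // still throws, but now with extra context about
                         // the failing node from the forward pass
    return 0;
}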