PyTorch RuntimeError: Trying to backward through the graph a second time

I am attempting to implement the Proximal Policy Optimization (PPO) algorithm in PyTorch, based on my understanding of how the algorithm works. At a high level, my implementation looks like this:

  1. First, we need to initialize the Policy Network with random parameters.
  2. Then, at each time step of the current episode, we compute the advantage function. This involves computing the value functions V(s) and Q(s,a) in order to get A(s,a) = V(s) - Q(s,a). In my implementation, each value function is its own function approximator, and both are optimized at every time step with a simple gradient-descent update.
  3. Additionally, at each time step of the episode, I compute the clipped surrogate values at the current step.
  4. At the end of the episode, I calculate the expected value of all the clipped surrogate values and use it for an SGA (stochastic gradient ascent) step that optimizes the Policy Network parameters (the objective itself is written out below this list).
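
The clipped surrogate objective from the PPO paper that steps 3 and 4 refer to is:

L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\ 1-\epsilon,\ 1+\epsilon)\,\hat{A}_t\right)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}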

My value functions and Policy Network are implemented as classes. The value function classes each contain PyTorch initialization and forward-pass methods, as well as MSE loss and SGD methods. The Policy Network class includes initialization and forward-pass methods, the clipped-surrogate-objective method used at each time step, and finally the SGA method. You can find the full code at https://github.com/BernardoOlisan/PPO-Clip/blob/main/ppo.py

Everything is working, but when I try to compute SGA for the policy network at the end of the episode, this error shows up:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Some people suggested that the issue might be due to my use of loss.backward() in each class (value functions and policy network), and advised me to resolve it by setting retain_graph=True in each loss.backward() instance. However, this solution is not effective.

I am unsure what is causing the error in my code. Why is the Policy Network SGA throwing this error? How can I fix it? Am I making a mistake?

Here is my policy rollout function:

def rollouts(self, policy, state_value_function, action_value_function):
    old_policy = copy.deepcopy(policy)

    for episode in range(self.EPISODES):
        print(f"###################################### EPISODE [{episode}] ######################################")

        (initial_state, _) = self.env.reset()
        self.env.render()
        state = initial_state
        clipped_surrogated_values = []

        for step in range(self.EPISODE_STEPS):
            print(f"---------------------- STEP [{step}] ----------------------")

            policy_output = policy.forwardPropagation(torch.from_numpy(state))
            old_policy_output = old_policy.forwardPropagation(torch.from_numpy(state))
            action = torch.distributions.Categorical(policy_output).sample().item()
            old_action = torch.distributions.Categorical(old_policy_output).sample().item()

            state, reward, terminated, truncated, info = self.env.step(action)

            target_value = torch.tensor([reward], dtype=torch.float32)

            predicted_state_value = state_value_function.forwardPropagation(torch.from_numpy(state))
            loss_state_value = state_value_function.MSE(predicted_state_value, target_value)
            state_value_function.SGD(loss_state_value, VALUE_LEARNING_RATE)

            predicted_action_value = action_value_function.forwardPropagation(torch.from_numpy(state), torch.tensor([action], dtype=torch.float32))
            loss_action_value = action_value_function.MSE(predicted_action_value, target_value)
            action_value_function.SGD(loss_action_value, VALUE_LEARNING_RATE)

            advantage = predicted_state_value - predicted_action_value
            ratio = torch.tensor((action + 1e-6) / (old_action + 1e-6), dtype=torch.float32)
            clipped_surrogate_objective = policy.clippedSurrogateObjective(ratio, advantage)
            clipped_surrogated_values.append(clipped_surrogate_objective)

            old_policy = copy.deepcopy(policy)

            # Policy Network Rollout
            print("Policy Network (π) Rollout:")
            print(f"Policy NN output: {policy_output}")
            print(f"Old Policy NN output: {old_policy_output}")
            print(f"Selected action: {action}")
            print(f"Reward received: {reward}\n")

            # State Value Function (V(s))
            print("State Value Function (V(s)):")
            print(f"Predicted Value: {predicted_state_value}")
            print(f"Loss (V(s)): {loss_state_value}\n")

            # Action Value Function (Q(s, a))
            print("Action Value Function (Q(s, a)):")
            print(f"Predicted Value: {predicted_action_value}")
            print(f"Loss (Q(s, a)): {loss_action_value}\n")

            # Clipped Surrogate Objective at t (J(θ_t))
            print("Clipped Surrogate Objective at t (J(θ_t)):")
            print(f"Ratio: {ratio}")
            print(f"Advantage: {advantage}")
            print(f"Result: {clipped_surrogate_objective}\n")

            if terminated:
                time.sleep(1)
                break
            time.sleep(0.1)

        expected_surrogate_objective = sum(clipped_surrogated_values)
        # Here is where the error shows up
        policy.SGA(expected_surrogate_objective, NN_LEARNING_RATE) 

        print(f"\nEPISODE DETAILS:")
        print(f"Expected Surrogate Objective: {expected_surrogate_objective}")

    self.env.close()

The Policy Network class is this one:

class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forwardPropagation(self, X):
        Y = torch.relu(self.fc1(X))
        Z = self.fc2(Y)
        z_hat = F.softmax(Z, dim=-1)
        return z_hat

    def clippedSurrogateObjective(self, ratio, advantage):
        first_value = ratio * advantage
        second_value = torch.clamp(ratio, 1 - EPSILON, 1 + EPSILON) * advantage
        return torch.min(first_value, second_value)

    def SGA(self, expected_surrogate_objective, learning_rate):
        ''' Error in here...
            RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after
            they have already been freed). Saved intermediate values of the graph are freed when you call .backward()
            or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if
            you need to access saved tensors after calling backward.
        '''
        optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        optimizer.zero_grad()

        loss = -expected_surrogate_objective
        loss.backward()
        optimizer.step()

The complete code can be found at https://github.com/BernardoOlisan/PPO-Clip/blob/main/ppo.py. It’s not very long.

Thank you in advance.

If you have multiple backward passes through connected computation graphs, the error is expected: the second backward pass will fail since the intermediates were already freed by the first call.
Use retain_graph=True in the first backward pass, or accumulate the losses before calling backward a single time. I also don’t know what “this solution is not effective” means.
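
To illustrate, here is a minimal, self-contained sketch of both patterns (a stand-in nn.Linear, not your actual networks):

import torch
import torch.nn as nn

net = nn.Linear(4, 1)
x = torch.randn(4)

# Two losses that share one computation graph through `out`.
out = net(x)
loss_a = out.sum()
loss_b = (out ** 2).sum()

loss_a.backward()       # frees the saved tensors of the shared graph
# loss_b.backward()     # would raise "Trying to backward through the graph a second time"

# Option 1: retain the graph for the second backward pass.
out = net(x)
out.sum().backward(retain_graph=True)
(out ** 2).sum().backward()           # works, the saved tensors were kept

# Option 2 (usually preferable): accumulate the losses, then backward once.
out = net(x)
loss = out.sum() + (out ** 2).sum()
loss.backward()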

Apologies; using retain_graph=True is not effective in solving the issue because I get a new error when setting retain_graph=True in all instances of loss.backward():

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [32, 1]], which is output 0 of AsStridedBackward0, is at version 20; expected version 19 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Also, there are 3 different neural networks, and each of them has its own loss.backward().

Any idea, @ptrblck?

Hi Bernardo!

I assume that the outputs of your Value functions depend somehow on
your Policy Network and that you need to take these dependencies into
account when you optimize your Policy Network. (If not, you could get
rid of the retain_graph=True and .detach() the outputs of your Value
functions before you use them in anything that has to do with your Policy
Network. Doing so would solve the “backward a second time” error you
reported in your first post.)

The problem is (if you leave the dependency of your Policy Network on
your Value functions in place) that when you call .backward() on your
Policy-Network loss, you will also backpropagate through those Value
functions. But you have already done something along the lines of calling
optimizer.step() on your Value functions and doing so modifies the
Parameters of your Value functions inplace, leading to the error you see.
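
A minimal sketch of that failure mode (hypothetical stand-ins for the
policy and value networks, not the code from the question):

import torch
import torch.nn as nn
import torch.optim as optim

policy_part = nn.Linear(4, 8)     # stand-in for the policy network
value_net = nn.Linear(8, 1)       # stand-in for a value function
value_opt = optim.SGD(value_net.parameters(), lr=0.1)

x = torch.randn(4)
h = policy_part(x)                # policy part of the graph
v = value_net(h)                  # value output depends on the policy part

value_loss = (v - 1.0).pow(2).sum()
value_loss.backward(retain_graph=True)
value_opt.step()                  # modifies value_net's Parameters inplace

policy_loss = -v.sum()            # still backpropagates through value_net
try:
    policy_loss.backward()        # needs the pre-step weight -> version error
except RuntimeError as e:
    print(e)                      # "... modified by an inplace operation ..."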

Assuming that this is what is going on, you should modify the forward
passes of your Value functions so that their Parameters (that are being
modified inplace) are not, themselves, needed in Policy Network’s
backward pass. You can probably achieve this with something like:

# in ValueFunction
linA = torch.nn.Linear(in_features, out_features)
...
# instead of
# y = linA(x)
# use
y = torch.nn.functional.linear(x, linA.weight.clone(), linA.bias.clone())

The point is that when backpropagating through ValueFunction, the
clone()s of linA’s Parameters are used and those clone()s haven’t
been modified by calling optimizer.step().

For further insight into what is going on in this situation I described or for
some suggestions on how to debug your issue if it’s something else that
is going on, please see this post:

Best.

K. Frank

That’s interesting, because there actually is a relationship between the Value Function networks and the Policy Network. The Policy Network uses an objective function as its loss (the clipped surrogate objective from the PPO algorithm), and this objective function needs an Advantage Function in order to be computed. The Advantage Function is calculated as the difference between the outputs of the two Value Function networks, so in theory that is the connection between the Value Networks and the Policy Network.

Passing the output of the Value Networks directly into the Advantage function, which is then used to calculate the Policy Network objective, also passes along the grad_fn (the computation graph) that the Value Networks produced. This issue can be resolved by using .detach().
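
A sketch of that change in the rollout loop above, reusing the same variable names (note that the ratio itself must still carry gradients from the policy network, as in the standard PPO probability ratio, otherwise the policy backward pass has nothing left to differentiate):

# Detach the value-network outputs so the policy objective no longer
# backpropagates into the (already optimized) value networks.
advantage = (predicted_state_value - predicted_action_value).detach()
clipped_surrogate_objective = policy.clippedSurrogateObjective(ratio, advantage)
clipped_surrogated_values.append(clipped_surrogate_objective)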

I didn’t think the advantage function was causing the issue, as the connection between the Value Networks and the Policy Network was difficult to see.

Thanks, K. Frank