RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 4096]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead

I am trying to implement an agent using proximal policy optimization (PPO). However, I am getting the following error:

C:\Users\Asus\anaconda3\lib\site-packages\torch\autograd\__init__.py:251: UserWarning: Error detected in AddmmBackward0. No forward pass information available. Enable detect anomaly during forward pass for more information. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\python_anomaly_mode.cpp:97.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "C:\Users\Asus\Desktop\RL_a\game.py", line 289, in <module>
    agent.update_policy(states, actions, rewards, log_probs, values, next_values, dones)
  File "C:\Users\Asus\Desktop\RL_a\ppox.py", line 95, in update_policy
    loss.backward(retain_graph=True)
  File "C:\Users\Asus\anaconda3\lib\site-packages\torch\_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "C:\Users\Asus\anaconda3\lib\site-packages\torch\autograd\__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 4096]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
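
For context, the message refers to autograd's version counter: a tensor that was saved for the backward pass was modified in place before backward ran. A minimal snippet, unrelated to my agent, that triggers the same class of error:

import torch

# Unrelated minimal example (not my agent code) of what this error class means:
# exp() saves its output for the backward pass, so editing that output in
# place bumps its version counter and backward then refuses to use it.
x = torch.randn(3, requires_grad=True)
y = x.exp()         # autograd saves y to compute exp's gradient
y.add_(1)           # in-place op: y's version goes from 0 to 1
y.sum().backward()  # RuntimeError: ... modified by an inplace operation ...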

My code is:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.nn.functional as F
import numpy as np

# Define the neural network for the policy
class PolicyNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, output_size)

    def forward(self, x):
        x = torch.relu(self.fc(x))
        x = self.fc2(x)
        return torch.softmax(x, dim=-1)

# Define the Proximal Policy Optimization agent
class PPOAgent:
    def __init__(self, input_size, output_size, lr=1e-3, gamma=0.99, epsilon=0.2, value_coef=0.5, entropy_coef=0.01):
        self.policy = PolicyNetwork(input_size, output_size)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef

    def select_action(self, state):
        #print("state: ", state)
        xstate = torch.from_numpy(state).float()
        probs = self.policy(xstate)
        m = Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action)

    def update_policy(self, states, actions, rewards, log_probs, values, next_values, dones):
        torch.autograd.set_detect_anomaly(True)
        returns = self.compute_returns(rewards, dones)
        """
        print("values", values, "next_values", next_values)
        print("size of values", len(values[0]), "size of next_values", len(next_values))
        print("type:",type(values[0]), type(values))
        print("returns", returns)
        print("size of returns", len(returns))
        print("type:", type(returns))
"""
        #convert values to tensor
        
        xvalues = torch.tensor(values, requires_grad=True).float()
        advantages = returns - xvalues
        print("advantages: ", advantages)
        

        for _ in range(ppo_epochs):
            for i in range(len(states)):
                state = torch.from_numpy(states[i]).float()
                action = torch.tensor(actions[i])
                old_log_prob = log_probs[i]
                value = xvalues[i]
                next_value = next_values[i]
                advantage = advantages[i]
                return_ = returns[i]

                # Compute the new log probability and value
                new_probs = self.policy(state)
                
                
                new_log_prob = torch.log(new_probs[action].clone())
                new_value = self.get_value(states[i])

                # Compute the surrogate loss
                ratio = torch.exp(new_log_prob - old_log_prob)
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantage
                policy_loss = -torch.min(surr1, surr2).mean()

                # Compute the value loss
                value_loss = F.mse_loss(new_value, return_)
                #print("lennnn: ",len(new_value), len(return_), len(advantage), len(ratio), len(surr1), len(surr2), len(policy_loss), len(value_loss), len(new_probs), len(new_log_prob), len(old_log_prob), len(advantage), len(advantages), len(returns), len(xvalues), len(value), len(next_value), len(action), len(state), len(states[i]), len(states), len(actions), len(log_probs), len(values), len(next_values), len(dones))

                

                # Compute the entropy loss
                entropy_loss = -torch.sum(new_probs * torch.log(new_probs + 1e-10))

                # Total loss
                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy_loss

                # Optimize the policy
                self.optimizer.zero_grad()
                print("loss: ", loss)
                with torch.autograd.detect_anomaly():
                    loss.backward(retain_graph=True)
                self.optimizer.step()

    def compute_returns(self, rewards, dones):
        returns = []
        R = 0

        for reward, done in zip(reversed(rewards), reversed(dones)):
            if done:
                R = 0
            Rz = reward + self.gamma * R
            returns.insert(0, Rz)

        returnsx = torch.tensor(returns).float()
        returnsy = (returnsx - returnsx.mean()) / (returnsx.std() + 1e-8)

        return returnsy

    def get_value(self, state):
        print("ggstate: ", state)
        statex = torch.from_numpy(state).float()
        return self.policy(statex)

# Set your environment parameters
input_size = 64  # Assuming a flat representation of the chess board as input
output_size = 64*64 # Number of legal moves in your chess environment

# Initialize the PPO agent
agent = PPOAgent(input_size, output_size)

# Training loop
num_episodes = 100
ppo_epochs = 4
"""
for episode in range(num_episodes):
    state = 0
    done = False

    states, actions, rewards, log_probs, values, next_values, dones = [], [], [], [], [], [], []

    while not done:
        action, log_prob = agent.select_action(state)
        next_state, reward, done, _ = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        log_probs.append(log_prob)
        values.append(agent.get_value(state))
        next_values.append(agent.get_value(next_state))
        dones.append(done)

        state = next_state

    agent.update_policy(states, actions, rewards, log_probs, values, next_values, dones)
"""

I tried to remove all in-place operations but still get the same error.

Could you explain why retain_graph=True is used, as it's often added as a workaround that hides the real issue?

I was getting the error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I found that retain_graph=True is suggested as a fix for that error, so I added it.
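
For reference, here is a minimal standalone snippet (toy tensors, not my agent code) that reproduces that error: a second backward through an already-freed graph fails, and what actually avoids it is rebuilding the graph with a fresh forward pass, not retain_graph=True.

import torch

# Toy repro (not my agent code) of the error that led me to retain_graph=True.
w = torch.randn(3, requires_grad=True)
loss = (w * w).sum()   # the multiply saves w for its backward pass
loss.backward()        # frees the tensors saved in this graph
# loss.backward()      # RuntimeError: Trying to backward through the graph a second time ...

# Re-running the forward pass builds a fresh graph, so another backward works
# without retain_graph=True.
loss = (w * w).sum()
loss.backward()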

That's not the case; it just creates this new issue.
Try to isolate the original problem and figure out why you are calling backward multiple times through the same graph.
Usually this happens because tensors that are still attached to the computation graph are stored on every iteration, so later backward calls reach back into graphs from earlier forward passes. A sketch of that pattern, fixed by detaching, is below.
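
Here is a rough sketch of that idea with a toy 4-input / 2-action policy (not your chess setup, just an illustration): detach everything you store during the rollout, so each backward() in the update loop only goes through the graph built by that iteration's forward pass.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

policy = nn.Linear(4, 2)                     # stand-in policy network
optimizer = optim.Adam(policy.parameters(), lr=1e-3)

# --- rollout: store everything detached, so no graph survives this phase ---
states, actions, old_log_probs = [], [], []
for _ in range(8):
    s = torch.randn(4)                       # fake state
    dist = Categorical(torch.softmax(policy(s), dim=-1))
    a = dist.sample()
    states.append(s)
    actions.append(a)
    old_log_probs.append(dist.log_prob(a).detach())  # old log-prob kept as a constant

# --- update: each iteration builds and consumes its own fresh graph ---
for s, a, old_lp in zip(states, actions, old_log_probs):
    dist = Categorical(torch.softmax(policy(s), dim=-1))
    ratio = torch.exp(dist.log_prob(a) - old_lp)      # old_lp carries no graph
    loss = -torch.clamp(ratio, 0.8, 1.2)              # stand-in for the clipped surrogate
    optimizer.zero_grad()
    loss.backward()                                   # plain backward, no retain_graph
    optimizer.step()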

I tried, but I couldn't figure out how to rearrange the code to avoid this. How can I structure the update so that backward is not called multiple times through the same graph?