RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation (torch version 2.0.0)

Here is the main code I ran, but I'm stuck on a problem:

import os.path
import matplotlib.pyplot as plt
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable

from ddpg.noise import OrnsteinUhlenbeckActionNoise as OUNoise
from ddpg.replaybuffer import Buffer
from ddpg.actorcritic import Actor, Critic
from util import *

# hyperparameters
NUM_STATES = 7
NUM_ACTIONS = 5
NUM_EPISODES = 9000
NUM_STEPS = 300
ACTOR_LR = 0.0003
CRITIC_LR = 0.003
SIGMA = 0.2
BUFFER_SIZE = 100000
MINIBATCH_SIZE = 64
CHECKPOINT_DIR = './checkpoints/'
DISCOUNT = 0.9
EPSILON = 1.0
EPSILON_DECAY = 1e-6
TAU = 0.001
WARMUP = 70  # should be greater than the minibatch size

PLOT_FIG = False
SAVE_FIG = False
SAVE_TO_FILE = False


# convert a state variable [s1, s2, ..., s_n] to a state tensor
def observation_to_state(state_list):
    return torch.FloatTensor(state_list).view(1, -1)


class DDPG:
    def __init__(self, env, function_name):
        self.env = env
        self.function_name = function_name
        self.state_dim = NUM_STATES
        self.action_dim = NUM_ACTIONS
        self.actor = Actor(self.state_dim, self.action_dim)
        self.critic = Critic(self.state_dim, self.action_dim)
        self.target_actor = deepcopy(Actor(self.state_dim, self.action_dim))
        self.target_critic = deepcopy(Critic(self.state_dim, self.action_dim))
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)
        self.critic_loss = nn.MSELoss()
        self.noise = OUNoise(mu=np.zeros(self.action_dim), sigma=SIGMA)
        self.replay_buffer = Buffer(BUFFER_SIZE)
        self.batch_size = MINIBATCH_SIZE
        self.checkpoint_dir = CHECKPOINT_DIR
        self.discount = DISCOUNT
        self.warmup = WARMUP
        self.epsilon = EPSILON
        self.epsilon_decay = EPSILON_DECAY
        self.reward_graph = []
        self.smoothed_reward_graph = []
        self.start = 0
        self.end = NUM_EPISODES
        self.tau = TAU

    # calculate target Q-value as reward and bootstrapped Q-value of next state via the target actor and target critic
    # inputs: Batch of next states, rewards and terminal flags of size self.batch_size
    # output: Batch of Q-value targets
    def get_q_target(self, next_state_batch, reward_batch, terminal_batch):
        target_batch = torch.FloatTensor(reward_batch)
        non_final_mask = torch.ByteTensor(tuple(map(lambda s: True if not s else False, terminal_batch)))
        next_state_batch = torch.cat(next_state_batch)
        next_action_batch = self.target_actor(next_state_batch)
        q_next = self.target_critic(next_state_batch, next_action_batch)

        non_final_mask = self.discount * non_final_mask.type(torch.FloatTensor)
        target_batch += non_final_mask * q_next.squeeze().data

        return Variable(target_batch).view(-1, 1)

    # weighted average update of the target network and original network
    # Inputs: target actor(critic) and original actor(critic)
    def update_targets(self, target, original):
        for target_param, original_param in zip(target.parameters(), original.parameters()):
            target_param.data.copy_((1 - self.tau) * target_param.data + self.tau * original_param.data)

    # get the action that returns the maximum Q-value
    # inputs: Current state of the episode
    # output: the action which maximizes the Q-value of the current state-action pair
    def get_max_action(self, curr_state):
        noise = self.epsilon * Variable(torch.FloatTensor(self.noise()))
        action = self.actor(curr_state)
        action_with_noise = action + noise

        # get the action with max value
        action_list = action_with_noise.tolist()[0]
        max_action = max(action_list)
        max_index = action_list.index(max_action)

        return max_index, action_with_noise

    # training of the original and target actor-critic networks
    def train(self):
        # create the checkpoint directory if not created
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        print('Training started...')

        for episode in range(self.start, self.end):
            print('Episode #' + str(episode) + ':')
            state_list = self.env.reset(self.function_name)
            episode_reward = 0
            for step in range(NUM_STEPS):
                print('Step #' + str(step) + ':')

                # print current state
                print_state(state_list)

                # get max action
                curr_state_tensor = torch.Tensor(observation_to_state(state_list))
                self.actor.eval()
                action_idx, action_to_buffer = self.get_max_action(curr_state_tensor)
                action = {
                    'vertical': 0,
                    'horizontal': 0,
                    'scale_to': -1
                }
                if action_idx == 0:
                    # do nothing
                    pass
                elif action_idx == 1:
                    # scaling out
                    action['horizontal'] = HORIZONTAL_SCALING_STEP
                elif action_idx == 2:
                    # scaling in
                    action['horizontal'] = -HORIZONTAL_SCALING_STEP
                elif action_idx == 3:
                    # scaling up
                    action['vertical'] = VERTICAL_SCALING_STEP
                elif action_idx == 4:
                    # scaling down
                    action['vertical'] = -VERTICAL_SCALING_STEP

                # print action
                print_action(action)

                self.actor.train()

                # step
                state_list, reward, done = self.env.step(self.function_name, action)

                # print reward
                print('Reward:', reward)

                next_state_tensor = Variable(observation_to_state(state_list))
                episode_reward += reward

                # update the replay buffer
                self.replay_buffer.append((curr_state_tensor, action_to_buffer, next_state_tensor, reward, done))

                # training loop
                if len(self.replay_buffer) >= self.warmup:
                    curr_state_batch, action_batch, next_state_batch, reward_batch, terminal_batch = \
                        self.replay_buffer.sample_batch(self.batch_size)
                    curr_state_batch = torch.cat(curr_state_batch)
                    action_batch = torch.cat(action_batch)

                    q_prediction_batch = self.critic(curr_state_batch, action_batch)
                    q_target_batch = self.get_q_target(next_state_batch, reward_batch, terminal_batch)

                    # critic update
                    self.critic_optimizer.zero_grad()
                    critic_loss = self.critic_loss(q_prediction_batch, q_target_batch)
                    print('Critic loss: {}'.format(critic_loss))
                    critic_loss.backward(retain_graph=True)
                    self.critic_optimizer.step()

                    # actor update
                    self.actor_optimizer.zero_grad()
                    actor_loss = -torch.mean(self.critic(curr_state_batch, self.actor(curr_state_batch)))
                    print('Actor loss: {}'.format(actor_loss))
                    actor_loss.backward(retain_graph=True)
                    self.actor_optimizer.step()

                    # update targets
                    self.update_targets(self.target_actor, self.actor)
                    self.update_targets(self.target_critic, self.critic)
                    self.epsilon -= self.epsilon_decay
                # end of current step
            # end of current episode
            print('EP #' + str(episode) + ': total reward =', episode_reward)
            self.reward_graph.append(episode_reward)
            self.smoothed_reward_graph.append(np.mean(self.reward_graph[-10:]))

            # save to checkpoint
            if episode % 20 == 0:
                self.save_checkpoint(episode)

            # plot the reward graph
            if PLOT_FIG:
                if episode % 1000 == 0 and episode != 0:
                    plt.plot(self.reward_graph, color='darkorange')
                    plt.plot(self.smoothed_reward_graph, color='b')
                    plt.xlabel('Episodes')
                    if SAVE_FIG:
                        plt.savefig('ep' + str(episode) + '.png')
        # end of all episodes
        if PLOT_FIG:
            plt.plot(self.reward_graph, color='darkorange')
            plt.plot(self.smoothed_reward_graph, color='b')
            plt.xlabel('Episodes')
            if SAVE_FIG:
                plt.savefig('final.png')

        if SAVE_TO_FILE:
            # write rewards to file
            file = open("ddpg_smoothed_rewards.txt", "w")
            for reward in self.smoothed_reward_graph:
                file.write(str(reward) + "\n")
            file.close()
            file = open("ddpg_episode_rewards.txt", "w")
            for reward in self.reward_graph:
                file.write(str(reward) + "\n")
            file.close()

    # save checkpoints to file
    def save_checkpoint(self, episode_num):
        checkpoint_name = self.checkpoint_dir + 'ddpg-ep{}.pth.tar'.format(episode_num)
        checkpoint = {
            'episode': episode_num,
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'target_actor': self.target_actor.state_dict(),
            'target_critic': self.target_critic.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
            'replay_buffer': self.replay_buffer,
            'reward_graph': self.reward_graph,
            'epsilon': self.epsilon
        }

        torch.save(checkpoint, checkpoint_name)

    # load checkpoints from file
    def load_checkpoint(self, checkpoint_file_name):
        if os.path.isfile(checkpoint_file_name):
            print('Loading checkpoint...')
            checkpoint = torch.load(checkpoint_file_name)
            self.start = checkpoint['episode'] + 1
            self.actor.load_state_dict(checkpoint['actor'])
            self.critic.load_state_dict(checkpoint['critic'])
            self.target_actor.load_state_dict(checkpoint['target_actor'])
            self.target_critic.load_state_dict(checkpoint['target_critic'])
            self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
            self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])
            self.replay_buffer = checkpoint['replay_buffer']
            self.reward_graph = checkpoint['reward_graph']
            self.epsilon = checkpoint['epsilon']
            print('Checkpoint successfully loaded')
        else:
            raise OSError('Checkpoint not found!')

The error I get is as follows:

C:\Users\12114\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\autograd\__init__.py:200: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
  File "c:\Users\12114\work\code\aware-main\aware-main\testing\main.py", line 86, in <module>
    main()
  File "c:\Users\12114\work\code\aware-main\aware-main\testing\main.py", line 82, in main
    agent.train()
  File "c:\Users\12114\work\code\aware-main\aware-main\testing\ddpg\ddpg.py", line 126, in train
    action_idx, action_to_buffer = self.get_max_action(curr_state_tensor)
  File "c:\Users\12114\work\code\aware-main\aware-main\testing\ddpg\ddpg.py", line 95, in get_max_action
    action = self.actor(curr_state)
  File "C:\Users\12114\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "c:\Users\12114\work\code\aware-main\aware-main\testing\ddpg\actorcritic.py", line 48, in forward
    h3 = self.fc3(h2_norm)
  File "C:\Users\12114\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\12114\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
 (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\python_anomaly_mode.cpp:119.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "c:\Users\12114\work\code\aware-main\aware-main\testing\main.py", line 86, in <module>
    main()
  File "c:\Users\12114\work\code\aware-main\aware-main\testing\main.py", line 82, in main
    agent.train()
  File "c:\Users\12114\work\code\aware-main\aware-main\testing\ddpg\ddpg.py", line 179, in train
    critic_loss.backward(retain_graph=True)
  File "C:\Users\12114\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "C:\Users\12114\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\autograd\__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [40, 5]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Hi HJ!

You can frequently “fix” inplace-modification errors with pytorch’s allow-mutation context
manager, torch.autograd.graph.allow_mutation_on_saved_tensors(). But this is usually just a
work-around and often a cop-out.

This post is a discussion of how to find and properly fix inplace-modification errors.
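
To make the failure mode concrete, here is a minimal, self-contained example of the general
pattern (it is not your code, just an illustration): an optimizer step() mutates a parameter
in place, and a later backward() through a graph retained from before the step then needs the
old value of that parameter.

import torch
import torch.nn as nn
import torch.optim as optim

layer = nn.Linear(3, 2)
opt = optim.SGD(layer.parameters(), lr=0.1)

x = torch.randn(4, 3, requires_grad=True)  # input requires grad, so backward needs layer.weight
out = layer(x)

loss1 = out.sum()
loss1.backward(retain_graph=True)  # keep the graph (and its reference to the saved layer.weight) alive
opt.step()                         # in-place update of layer.weight -> its version counter bumps

loss2 = out.pow(2).sum()
loss2.backward()                   # RuntimeError: ... modified by an inplace operation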

This is likely the cause of your problem: critic_optimizer.step() performs inplace
modifications of the critic’s parameters. But because you called .backward() with
retain_graph = True, those modified parameters are likely getting used again when
you backpropagate again through the old graph that you retained.
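
For what it’s worth, one common way to restructure this kind of DDPG update so that no old
graph is reused at all (and retain_graph=True becomes unnecessary) is sketched below. It is
only a sketch against your variable names, and it assumes the actions stored in the replay
buffer are detached from the rollout graph (for example by appending action_to_buffer.detach()
instead of action_to_buffer), so that each loss is backpropagated through a graph built
entirely within the current iteration.

# assumes action_batch carries no autograd history (detached when buffered)
q_prediction_batch = self.critic(curr_state_batch, action_batch)
q_target_batch = self.get_q_target(next_state_batch, reward_batch, terminal_batch)

# critic update: this graph is backpropagated through exactly once
self.critic_optimizer.zero_grad()
critic_loss = self.critic_loss(q_prediction_batch, q_target_batch)
critic_loss.backward()
self.critic_optimizer.step()

# actor update: a fresh forward pass builds a fresh graph after the critic step,
# so no just-updated parameter is ever reused inside an old, retained graph
self.actor_optimizer.zero_grad()
actor_loss = -torch.mean(self.critic(curr_state_batch, self.actor(curr_state_batch)))
actor_loss.backward()
self.actor_optimizer.step()

With that arrangement each optimizer step() only modifies parameters after every graph that
saved them has already been consumed, which is exactly the condition the error message is
checking for.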

Best.

K. Frank