RuntimeError with in-place operation while training DDPG

Hi,
Our project uses a meta model to assist with DDPG training. Both meta.py and ddpg.py contain an Actor class and a Critic class.
The Actor class in ddpg.py and the one in meta.py have identical contents; they just live in different files. The same is true for the Critic.
When DDPG training and meta training are run independently, everything executes normally.
But when both are run together, the following error occurs.
This is the error message I get after adding torch.autograd.set_detect_anomaly(True):

[W ..\torch\csrc\autograd\python_anomaly_mode.cpp:104] Warning: Error detected in MmBackward0. Traceback of forward call that caused the error:
  File "C:\Users\$STG000-RQUJF6OTH79G\Desktop\meta2-pytorch\meta_v9\main2.py", line 843, in <module>
    train(on_group = ON_GROUP) #train(scenario_list=train_set)
  File "C:\Users\$STG000-RQUJF6OTH79G\Desktop\meta2-pytorch\meta_v9\main2.py", line 222, in train
    meta_action = meta_rl.choose_action(meta_state, meta_state_mcs)
  File "C:\Users\$STG000-RQUJF6OTH79G\Desktop\meta2-pytorch\meta_v9\META_ori_local.py", line 251, in choose_action
    action = self.actor(self.state_concatenate)#.cpu().numpy()  # Forward pass through the actor network
  File "C:\Users\$STG000-RQUJF6OTH79G\.conda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\$STG000-RQUJF6OTH79G\Desktop\meta2-pytorch\meta_v9\META_ori_local.py", line 78, in forward
    actions = torch.tanh(self.fc3(x_clone)) # TensorFlow: a = tf.keras.layers.Dense(self.a_dim, activation=tf.nn.tanh, name='a', trainable=trainable)(net2)
  File "C:\Users\$STG000-RQUJF6OTH79G\.conda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\$STG000-RQUJF6OTH79G\.conda\envs\pytorch\lib\site-packages\torch\nn\modules\linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\$STG000-RQUJF6OTH79G\.conda\envs\pytorch\lib\site-packages\torch\nn\functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
 (function _print_stack)
Traceback (most recent call last):
  File "C:\Users\$STG000-RQUJF6OTH79G\Desktop\meta2-pytorch\meta_v9\main2.py", line 843, in <module>
    train(on_group = ON_GROUP) #train(scenario_list=train_set)
  File "C:\Users\$STG000-RQUJF6OTH79G\Desktop\meta2-pytorch\meta_v9\main2.py", line 258, in train
    rl.learn(t_slot)
  File "C:\Users\$STG000-RQUJF6OTH79G\Desktop\meta2-pytorch\meta_v9\Attention_DDPG_ori_local.py", line 301, in learn
    critic_loss.backward()
  File "C:\Users\$STG000-RQUJF6OTH79G\.conda\envs\pytorch\lib\site-packages\torch\_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "C:\Users\$STG000-RQUJF6OTH79G\.conda\envs\pytorch\lib\site-packages\torch\autograd\__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 3]], which is output 0 of AsStridedBackward0, is at version 13; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
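For context, my understanding of this error class is that some tensor saved for the backward pass was changed in place (its version counter was bumped) before backward() ran. A minimal, hypothetical snippet (not from my project) that produces the same kind of message:

import torch

torch.autograd.set_detect_anomaly(True)

w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(4, 3)

y = x.matmul(w)   # autograd saves x because it is needed to compute d(loss)/dw
loss = y.sum()

x.add_(1.0)       # in-place change bumps x's version counter

loss.backward()   # RuntimeError: ... modified by an inplace operation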

I’m not sure where the in-place operation happens in my code, but I suspect the actor is causing the error.
So, below is my Actor class.

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, action_bound, dropout_rate):
        super(Actor, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Actor network architecture
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)
        self.action_bound = action_bound
        self.dropout = nn.Dropout(p = dropout_rate)
        
    def forward(self, state):        
        state = state.to(self.device)
        # Forward pass of the actor network
        x1 = torch.relu(self.fc1(state))
        x1_drop = self.dropout(x1)
        x2 = torch.relu(self.fc2(x1_drop))
        x2_drop = self.dropout(x2)
        actions = torch.tanh(self.fc3(x2_drop))
        
        return actions * self.action_bound

This problem has bothered me for several days; I hope it can be solved.

Best regards,
Ning

I don’t think the Actor itself is causing the issue, as it should then also fail in isolation. I would, however, suspect that your training code, the interaction between both models, and the section updating the parameters might cause the error.
You could take a look at e.g. this topic to understand when stale forward activations could cause issues.
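A minimal sketch of that pattern (an optimizer step between the forward pass and a later backward pass through the same, now stale, graph) would be:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
opt = torch.optim.SGD(model.parameters(), lr=0.1)

out = model(torch.randn(16, 4))     # forward pass saves the weights for backward
loss1 = out.pow(2).mean()
loss1.backward(retain_graph=True)   # first backward works
opt.step()                          # updates the weights in place (version counter +1)

loss2 = out.mean()                  # reuses the stale forward activation
loss2.backward()                    # RuntimeError: ... modified by an inplace operation

This is only an illustration of the failure mode, not a claim about where exactly it happens in your code.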

Hello @ptrblck, that post seems to address a problem related to retain_graph=True. However, in my learn function I only call loss.backward() without retain_graph=True, so retain_graph keeps its default value of False.
Because the DDPG design I follow does not use retain_graph=True, I did not use it either. This is the code that I referenced.
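To double-check that distinction, the missing-retain_graph case gives a different message in a small test (hypothetical snippet, not my project code):

import torch

x = torch.randn(5, requires_grad=True)
loss = (x * x).sum()   # MulBackward saves x for the backward pass

loss.backward()        # frees the saved tensors of the graph
loss.backward()        # RuntimeError: Trying to backward through the graph a second time ...

So I believe the error I see is the in-place one, not the missing-retain_graph one.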

And the following is my learn function:

def learn(self, t_slot):
        if not self.is_training: # for not onTrain
            return  # Exit the function if not training
    
        critic_grad_norm = 0.0 # for plot gradient
        actor_grad_norm = 0.0 # for plot gradient

        memory_capacity = self.MEMORY_CAPACITY
        batch_size = self.BATCH_SIZE

        # Generate random indices to sample a batch of size batch_size from the replay memory
        indices = np.random.choice(memory_capacity, batch_size)

        # Extract a batch of experiences from the replay memory buffer using random `indices`
        batch = self.memory[indices, :]

        # Extract the current states from the batch
        batch_states = batch[:, :self.s_dim]
        # Extract the actions obtained after taking the current states from the batch
        batch_actions = batch[:, self.s_dim: self.s_dim + self.a_dim]
        # Extract the rewards obtained after taking the actions from the batch
        batch_rewards = batch[:, -self.s_dim - 1: -self.s_dim]
        # Extract the next states observed after taking the actions from the batch
        batch_next_states = batch[:, -self.s_dim:]
        
        # Compute target Q value
        target_action = self.actor_target(batch_next_states)
        target_q_value = self.critic_target(batch_next_states, target_action)
        target_q_value = batch_rewards + self.GAMMA * target_q_value
        
        # Update critic
        current_q_value = self.critic(batch_states, batch_actions)
        critic_loss = F.mse_loss(current_q_value, target_q_value) # td_error
        critic_loss.backward()
        self.critic_optimizer.step()
        for param in self.critic.parameters(): # for plot gradient
            if param.grad is not None:
                critic_grad_norm += param.grad.data.norm(2).item() ** 2
        self.critic_optimizer.zero_grad()
        
        # Update actor
        predicted_action = self.actor(batch_states)
        actor_loss = self.critic(batch_states, predicted_action)
        actor_loss = -torch.mean(actor_loss)
        actor_loss.backward() # start a new gradient
        self.actor_optimizer.step() # update gradient to actor from actor_optimizer
        for param in self.actor.parameters(): # for plot gradient
            if param.grad is not None:
                actor_grad_norm += param.grad.data.norm(2).item() ** 2
        self.actor_optimizer.zero_grad() # clear the gradient

        critic_grad_norm = np.sqrt(critic_grad_norm) # for plot gradient
        actor_grad_norm = np.sqrt(actor_grad_norm) # for plot gradient
        
        if t_slot % 5000 == 0 and t_slot != 0:
            self.critic_losses.append(critic_loss.item())
            self.actor_losses.append(actor_loss.item())
            self.gradients.append((critic_grad_norm, actor_grad_norm)) # for plot gradient
        
        # Update target networks with soft updates
        self.soft_update(self.actor_target, self.actor)
        self.soft_update(self.critic_target, self.critic)

def soft_update(self, target, online):
        for target_param, online_param in zip(target.parameters(), online.parameters()):
            target_param.data = (1 - self.TAU) * target_param.data + self.TAU * online_param.data
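For comparison, my understanding of a common DDPG update order is roughly the toy sketch below (placeholder networks and dimensions chosen by me, not my actual code, and I am not sure it applies to my setup):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins so the sketch runs; all names and dimensions are placeholders.
s_dim, a_dim, GAMMA, TAU = 4, 2, 0.9, 0.01
actor, actor_target = nn.Linear(s_dim, a_dim), nn.Linear(s_dim, a_dim)
critic, critic_target = nn.Linear(s_dim + a_dim, 1), nn.Linear(s_dim + a_dim, 1)
actor_opt = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_opt = torch.optim.Adam(critic.parameters(), lr=1e-3)

batch_states = torch.randn(64, s_dim)
batch_actions = torch.randn(64, a_dim)
batch_rewards = torch.randn(64, 1)
batch_next_states = torch.randn(64, s_dim)

# Target Q values are computed without building a graph.
with torch.no_grad():
    target_action = actor_target(batch_next_states)
    target_q = batch_rewards + GAMMA * critic_target(
        torch.cat([batch_next_states, target_action], dim=1))

# Critic update: clear gradients, backward, step.
critic_opt.zero_grad()
current_q = critic(torch.cat([batch_states, batch_actions], dim=1))
critic_loss = F.mse_loss(current_q, target_q)
critic_loss.backward()
critic_opt.step()

# Actor update: clear gradients, backward, step.
actor_opt.zero_grad()
actor_loss = -critic(torch.cat([batch_states, actor(batch_states)], dim=1)).mean()
actor_loss.backward()
actor_opt.step()

# Soft update of the target networks, done without tracking gradients.
with torch.no_grad():
    for target_net, online_net in ((actor_target, actor), (critic_target, critic)):
        for t_p, p in zip(target_net.parameters(), online_net.parameters()):
            t_p.mul_(1 - TAU).add_(TAU * p)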

And this is the Critic class:

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Critic network architecture
        self.w1_s = nn.Parameter(torch.randn(state_dim, 64))
        self.w1_a = nn.Parameter(torch.randn(action_dim, 64))

        self.fc = nn.Linear(64, 1)

    def forward(self, state, action):        
        state = state.to(self.device)
        action = action.to(self.device)
        # Forward pass of the critic network
        x_state = torch.matmul(state, self.w1_s)
        x_action = torch.matmul(action, self.w1_a)
        x = torch.relu(x_state + x_action)
        q_value = self.fc(x)
        
        return q_value

The following is my training process (1 episode equals 5000 time slots):
choose_action() in ddpg.py (every time slot) → choose_action() in meta.py (every time slot) → [after 1 episode] learn() in meta.py (every 500 time slots) → [after 2 episodes] learn() in ddpg.py (every 500 time slots)
The program fails when entering learn() in ddpg.py.

And this is my code for choose_action():

def choose_action(self, state, state_mcs):
        if torch.cuda.is_available():
            state = state.to(self.device)
            state_mcs = state_mcs.to(self.device)
        
        self.state_concatenate = torch.cat((state, state_mcs))
        
        action = self.actor(self.state_concatenate)  # Forward pass through the actor network
        return action
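In case it is relevant, a variant I have seen in other examples runs action selection under torch.no_grad(), so the returned action does not keep the forward graph alive. I have not verified whether this applies to my setup; it would look like this sketch:

def choose_action(self, state, state_mcs):
    if torch.cuda.is_available():
        state = state.to(self.device)
        state_mcs = state_mcs.to(self.device)

    self.state_concatenate = torch.cat((state, state_mcs))

    # Run the actor without building an autograd graph for action selection.
    with torch.no_grad():
        action = self.actor(self.state_concatenate)
    return action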

Best regards,
Ning