RuntimeError: Gradient computation modified by an inplace operation

Hello, I’m programming my first PPO agent and have encountered an error during training.

Specifically, PyTorch raises a RuntimeError during the second backward pass, saying that a tensor required for gradient computation has been modified by an in-place operation. The first backward pass completes without problems; the error only appears on the second one. The problem goes away when I leave old_log_probs out of the loss computation. I’ve tried enabling torch.autograd.set_detect_anomaly(True), but it didn’t give me any additional insight.
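
For reference, I enable anomaly detection once at start-up, roughly like this (a minimal sketch; run_training() is just a stand-in for my actual entry point):

import torch

# Enabled before any forward or backward passes are run
torch.autograd.set_detect_anomaly(True)

run_training()  # stand-in for my actual training entry point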

Here’s the error message:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 9]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
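
For comparison, here is a tiny standalone snippet (not from my project) that raises the same class of error. I’m not claiming the mechanism is the same in my code; it just shows the pattern of a saved graph being reused after an in-place parameter update:

import torch

w = torch.nn.Parameter(torch.tensor(2.0))
optimizer = torch.optim.SGD([w], lr=0.1)

old_output = w ** 2                    # autograd saves w for the backward of pow
for step in range(2):
    loss = old_output * 3.0            # reuses the graph built before the update
    optimizer.zero_grad()
    loss.backward(retain_graph=True)   # fails on the second iteration
    optimizer.step()                   # in-place update of w bumps its version counter

Here the second backward() fails because optimizer.step() updated w in place after w had been saved for the first graph, which is at least the same kind of version-counter mismatch the message above describes.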

Some context: I am working on a large-scale warehouse simulation in which an agent determines the best storage locations for products. The agent is trained with Proximal Policy Optimization (PPO), with both a policy network and a value network involved in training. I’ve already tested the rest of the code, including the reward calculations and other core functions, with a simple CNN and a DDQN agent, both of which work fine.

Due to the size of the project, I can only share a portion of the code, focusing on the training loop and the parts that seem to be causing the issue. If you need more context to track down the error, let me know!

Here’s the relevant code:

def _predict_policy_network(self, agent_state):
    """Pass state through shared encoder and policy network."""
    return self.policy_network(agent_state).squeeze(0)

def _predict_value_network(self, states):
    """Pass state through shared encoder and value network."""
    return self.value_network(states).squeeze(-1)

# ---------------- Decision Making and Network Training ----------------- #

def get_action(self, agent_state):
    """
    Determines the next action based on the agent's state.
    
    Returns:
        - action: The selected action (index of the chosen storage location).
        - log_prob: The logarithmic probability of selecting that action.
        - logits: The predictions for selecting every action.
    """
    # Reshape agent state to [1, 3, input_dim] and move to device
    agent_state = agent_state.unsqueeze(0).to(self.device)

    # Ensure gradients are enabled during training
    if self.phase == Phase.TRAINING.value:
        logits = self._predict_policy_network(agent_state)
    
    elif self.phase == Phase.VALIDATION.value:
        # Use no_grad() for validation
        with torch.no_grad():
            logits = self._predict_policy_network(agent_state)

    else:
        # Guard against logits being undefined if an unexpected phase value slips through
        raise ValueError(f"Unexpected phase: {self.phase}")

    # Fetch probabilities for easier logging
    probabilities = torch.softmax(logits, dim=-1)

    # Sample an action based on policy logits
    distribution = dist.Categorical(logits=logits)
    action = distribution.sample()
    log_prob = distribution.log_prob(action)

    logging.info(f"\nLogits: \n{logits} \nProbabilities: \n{probabilities} \nChosen action: {action.item()}")

    return action.item(), log_prob, logits

def train(self, training_samples):
    # Gather data from training samples, including precomputed advantages and returns
    states, actions, old_log_probs, old_logits, advantages, returns = self._fetch_data_from_samples(training_samples)

    for epoch in range(self.num_epochs):
        # Get policy logits and compute new log probabilities
        new_logits = self._predict_policy_network(states)
        distribution = dist.Categorical(logits=new_logits)
        new_log_probs = distribution.log_prob(actions)

        # Entropy regularization for exploration
        entropy = distribution.entropy().mean()
        logging.info(f"Entropy = {entropy}")

        # Compute losses using precomputed advantages and returns
        policy_loss, policy_loss_detached, value_loss = self._compute_losses(
            new_log_probs, 
            old_log_probs, 
            advantages, 
            returns, 
            states,
            entropy
        )

        # KL-divergence calculation for early stopping
        kl_divergence = self._compute_kl_divergence(old_logits, new_logits)

        # Compute the total loss with KL as regularization term
        total_loss = policy_loss + self.value_coef * value_loss + self.kl_coef * kl_divergence

        # Backpropagation with total_loss
        self.policy_optimizer.zero_grad()
        self.value_optimizer.zero_grad()
        total_loss.backward(retain_graph=True)
        self.apply_grad_clipping(self.policy_network)
        self.apply_grad_clipping(self.value_network)
        self.policy_optimizer.step()
        self.value_optimizer.step()

    self._adjust_entropy_coef()

    # Pass detached values for postprocessing
    return policy_loss_detached, value_loss.detach(), kl_divergence.detach(), entropy.detach()

def _fetch_data_from_samples(self, training_samples):
    batch_size = len(training_samples)

    # Initialize tensors for batch processing
    states = torch.empty((batch_size, *training_samples[0][0].shape), device=self.device)
    actions = torch.empty((batch_size,), dtype=torch.long, device=self.device)
    old_log_probs = torch.empty((batch_size,), dtype=torch.float32, device=self.device)
    old_logits = torch.empty((batch_size, *training_samples[0][3].shape), device=self.device)
    advantages = torch.empty((batch_size,), dtype=torch.float32, device=self.device)
    returns = torch.empty((batch_size,), dtype=torch.float32, device=self.device)

    # Load the samples into the batch tensors
    for i, sample in enumerate(training_samples):
        states[i] = sample[0]
        actions[i] = sample[1]
        old_log_probs[i] = sample[2].clone()
        old_logits[i] = sample[3].clone()
        advantages[i] = sample[4]
        returns[i] = sample[5]

    normalized_advantages, normalized_returns = self._normalize_data(advantages, returns)

    return states, actions, old_log_probs, old_logits, normalized_advantages, normalized_returns

def _normalize_data(self, advantages, returns):
    # Normalize advantages
    advantage_mean = advantages.mean()
    advantage_std = advantages.std() + 1e-8  # Add a small constant to avoid division by zero
    normalized_advantages = (advantages - advantage_mean) / advantage_std

    # Normalize returns
    return_mean = returns.mean()
    return_std = returns.std() + 1e-8  # Add a small constant to avoid division by zero
    normalized_returns = (returns - return_mean) / return_std

    return normalized_advantages, normalized_returns

def _compute_losses(self, new_log_probs, old_log_probs, advantages, returns, states, entropy):
    # Calculate the policy (actor) loss with clipping
    ratio = torch.exp(new_log_probs - old_log_probs)
    clipped_ratio = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio)
    policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

    # Detach for postprocessing
    policy_loss_detached = policy_loss.detach().clone()
    
    # Add entropy regularization to the policy loss
    policy_loss = policy_loss - self.entropy_coef * entropy

    # Calculate the value (critic) loss separately
    values = self._predict_value_network(states)
    value_loss = nn.MSELoss()(values, returns)

    return policy_loss, policy_loss_detached, value_loss

def _compute_and_save_advantages_and_returns(self, product_id, next_state):
    """
    Computes the advantage and expected return for a given action in a given state.
    
    Arguments:
        - agent_state: The current state of the agent.
        - action: The action for which to compute the advantage and return.
    
    Returns:
        - advantage: Advantage of the selected action.
        - return_: Expected return of the selected action.
    """
    if product_id in self.training_dataset.pending_samples:
        # Fetch previous agent state and reward
        agent_state = self.training_dataset.pending_samples[product_id]['state'].unsqueeze(0).to(self.device)
        reward = self.training_dataset.pending_samples[product_id]['reward']
        next_state = next_state.unsqueeze(0).to(self.device)

        # Calculate value (critic) for current state and next state
        value = self._predict_value_network(agent_state)
        next_value = self._predict_value_network(next_state)

        # Calculate delta for advantage
        delta = reward + GAMMA * next_value - value
        advantage = delta + (GAMMA * LAMBDA * delta)
        return_ = reward + GAMMA * next_value

        # Append training dataset with calculated values
        self.training_dataset.pending_samples[product_id]['advantage'] = advantage.detach()
        self.training_dataset.pending_samples[product_id]['return'] = return_.detach()

        logging.info(f"\nAdvantage for previous state: {round(advantage.item(), 3)} \nReturn for previous state: {round(return_.item(), 3)}")

        # Move pending sample to complete samples dict
        self.training_dataset._add_complete_sample(product_id)

Additional Information:

  • PyTorch Version: 2.3.1
  • Operating System: Linux (WSL2)
  • Kernel Version: 5.15.153.1-microsoft-standard-WSL2
  • Architecture: x86_64

Any insights into why this happens or how I can fix it would be greatly appreciated!