Hello, I’m programming my first PPO agent and have run into an error during training. Specifically, I’m getting a RuntimeError in PyTorch during the second backward pass, which suggests that an in-place operation has modified a tensor required for gradient computation. Interestingly, the first backward pass completes without any problem; the error only appears on the second one. The issue temporarily goes away when I leave old_log_probs out of the loss computation. I’ve tried enabling torch.autograd.set_detect_anomaly(True), but unfortunately it didn’t give me any additional insight.
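In case the way I enabled it matters: as far as I understand, anomaly detection can be switched on globally or used as a context manager around the backward pass, and I tried it roughly like this (toy sketch, not my actual training code):

import torch

# Global switch, turned on once near the start of the run
torch.autograd.set_detect_anomaly(True)

# ...or as a context manager around a specific backward pass
x = torch.ones(3, requires_grad=True)
with torch.autograd.detect_anomaly():
    (x * 2).sum().backward()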
Here’s the error message:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 9]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
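As far as I can tell, the same type of RuntimeError can be triggered with a small standalone snippet like the following (this is not my actual code, just my attempt to reproduce the error class, so it may not be the same mechanism as in my training loop):

import torch

linear = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)

x = torch.randn(4, 2, requires_grad=True)
loss = linear(x).sum()

loss.backward(retain_graph=True)  # first backward pass works
optimizer.step()                  # updates linear.weight in place
loss.backward()                   # second backward through the retained graph fails

I’m not sure whether this is really what happens in my code, but the symptom looks the same: the first backward pass succeeds and a later one fails.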
I am working on a large-scale warehouse simulation in which an AI agent determines the best storage locations for products. The agent is trained with Proximal Policy Optimization (PPO), using both a policy network and a value network. I’ve already tested the rest of the code, including the reward calculation and other core functions, with a simple CNN and a DDQN agent, and both work fine.
Due to the size of the project, I can only share a portion of the code, focusing on the training loop and the parts that seem to be causing the issue. If you need more context to track down the error, let me know!
Here’s the relevant code:
def _predict_policy_network(self, agent_state):
    """Pass state through shared encoder and policy network."""
    return self.policy_network(agent_state).squeeze(0)

def _predict_value_network(self, states):
    """Pass state through shared encoder and value network."""
    return self.value_network(states).squeeze(-1)
# ---------------- Decision Making and Network Training ----------------- #
def get_action(self, agent_state):
    """
    Determines the next action based on the agent's state.

    Returns:
    - action: The selected action (index of the chosen storage location).
    - log_prob: The logarithmic probability of selecting that action.
    - logits: The predictions for selecting every action.
    """
    # Reshape agent state to [1, 3, input_dim] and move to device
    agent_state = agent_state.unsqueeze(0).to(self.device)

    # Ensure gradients are enabled during training
    if self.phase == Phase.TRAINING.value:
        logits = self._predict_policy_network(agent_state)
    elif self.phase == Phase.VALIDATION.value:
        # Use no_grad() for validation
        with torch.no_grad():
            logits = self._predict_policy_network(agent_state)

    # Fetch probabilities for easier logging
    probabilities = torch.softmax(logits, dim=-1)

    # Sample an action based on policy logits
    distribution = dist.Categorical(logits=logits)
    action = distribution.sample()
    log_prob = distribution.log_prob(action)

    logging.info(f"\nLogits: \n{logits} \nProbabilities: \n{probabilities} \nChosen action: {action.item()}")

    return action.item(), log_prob, logits
def train(self, training_samples):
    # Gather data from training samples, including precomputed advantages and returns
    states, actions, old_log_probs, old_logits, advantages, returns = self._fetch_data_from_samples(training_samples)

    for epoch in range(self.num_epochs):
        # Get policy logits and compute new log probabilities
        new_logits = self._predict_policy_network(states)
        distribution = dist.Categorical(logits=new_logits)
        new_log_probs = distribution.log_prob(actions)

        # Entropy regularization for exploration
        entropy = distribution.entropy().mean()
        logging.info(f"Entropy = {entropy}")

        # Compute losses using precomputed advantages and returns
        policy_loss, policy_loss_detached, value_loss = self._compute_losses(
            new_log_probs,
            old_log_probs,
            advantages,
            returns,
            states,
            entropy
        )

        # KL-divergence calculation for early stopping
        kl_divergence = self._compute_kl_divergence(old_logits, new_logits)

        # Compute the total loss with KL as regularization term
        total_loss = policy_loss + self.value_coef * value_loss + self.kl_coef * kl_divergence

        # Backpropagation with total_loss
        self.policy_optimizer.zero_grad()
        self.value_optimizer.zero_grad()
        total_loss.backward(retain_graph=True)
        self.apply_grad_clipping(self.policy_network)
        self.apply_grad_clipping(self.value_network)
        self.policy_optimizer.step()
        self.value_optimizer.step()

        self._adjust_entropy_coef()

    # Pass detached values for postprocessing
    return policy_loss_detached, value_loss.detach(), kl_divergence.detach(), entropy.detach()
def _fetch_data_from_samples(self, training_samples):
    batch_size = len(training_samples)

    # Initialize tensors for batch processing
    states = torch.empty((batch_size, *training_samples[0][0].shape), device=self.device)
    actions = torch.empty((batch_size,), dtype=torch.long, device=self.device)
    old_log_probs = torch.empty((batch_size,), dtype=torch.float32, device=self.device)
    old_logits = torch.empty((batch_size, *training_samples[0][3].shape), device=self.device)
    advantages = torch.empty((batch_size,), dtype=torch.float32, device=self.device)
    returns = torch.empty((batch_size,), dtype=torch.float32, device=self.device)

    # Load the samples into the batch tensors
    for i, sample in enumerate(training_samples):
        states[i] = sample[0]
        actions[i] = sample[1]
        old_log_probs[i] = sample[2].clone()
        old_logits[i] = sample[3].clone()
        advantages[i] = sample[4]
        returns[i] = sample[5]

    normalized_advantages, normalized_returns = self._normalize_data(advantages, returns)
    return states, actions, old_log_probs, old_logits, normalized_advantages, normalized_returns
def _normalize_data(self, advantages, returns):
    # Normalize advantages
    advantage_mean = advantages.mean()
    advantage_std = advantages.std() + 1e-8  # Add a small constant to avoid division by zero
    normalized_advantages = (advantages - advantage_mean) / advantage_std

    # Normalize returns
    return_mean = returns.mean()
    return_std = returns.std() + 1e-8  # Add a small constant to avoid division by zero
    normalized_returns = (returns - return_mean) / return_std

    return normalized_advantages, normalized_returns
def _compute_losses(self, new_log_probs, old_log_probs, advantages, returns, states, entropy):
    # Calculate the policy (actor) loss with clipping
    ratio = torch.exp(new_log_probs - old_log_probs)
    clipped_ratio = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio)
    policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

    # Detach for postprocessing
    policy_loss_detached = policy_loss.detach().clone()

    # Add entropy regularization to the policy loss
    policy_loss = policy_loss - self.entropy_coef * entropy

    # Calculate the value (critic) loss separately
    values = self._predict_value_network(states)
    value_loss = nn.MSELoss()(values, returns)

    return policy_loss, policy_loss_detached, value_loss
def _compute_and_save_advantages_and_returns(self, product_id, next_state):
    """
    Computes the advantage and expected return for the pending sample of a given product
    and stores them in the training dataset.

    Arguments:
    - product_id: Identifier of the product whose pending sample should be completed.
    - next_state: The state observed after the stored action was taken.
    """
    if product_id in self.training_dataset.pending_samples:
        # Fetch previous agent state and reward
        agent_state = self.training_dataset.pending_samples[product_id]['state'].unsqueeze(0).to(self.device)
        reward = self.training_dataset.pending_samples[product_id]['reward']
        next_state = next_state.unsqueeze(0).to(self.device)

        # Calculate value (critic) for current state and next state
        value = self._predict_value_network(agent_state)
        next_value = self._predict_value_network(next_state)

        # Calculate delta for advantage
        delta = reward + GAMMA * next_value - value
        advantage = delta + (GAMMA * LAMBDA * delta)
        return_ = reward + GAMMA * next_value

        # Append training dataset with calculated values
        self.training_dataset.pending_samples[product_id]['advantage'] = advantage.detach()
        self.training_dataset.pending_samples[product_id]['return'] = return_.detach()
        logging.info(f"\nAdvantage for previous state: {round(advantage.item(), 3)} \nReturn for previous state: {round(return_.item(), 3)}")

        # Move pending sample to complete samples dict
        self.training_dataset._add_complete_sample(product_id)
Additional Information:
- PyTorch Version: 2.3.1
- Operating System: Linux (WSL2)
- Kernel Version: 5.15.153.1-microsoft-standard-WSL2
- Architecture: x86_64
Any insights into why this happens or how I can fix it would be greatly appreciated!