I am training models, and I am noticing that they are not converging to the optimal policy. I have made several modifications, but now the Q-values are becoming NaN; any insight as to why?
Can you share some more information?
- Loss function;
- Optimizer and lr;
- Reward function;
- Batch size.
I am using the default TD3 algorithm & parameters provided by the author:
```python
# Select action according to policy and add clipped noise
noise = (
    torch.randn_like(action) * self.policy_noise
).clamp(-self.noise_clip, self.noise_clip)

next_action = (
    self.actor_target(next_state) + noise
).clamp(-self.max_action, self.max_action)

# Compute the target Q value
target_Q1, target_Q2 = self.critic_target(next_state, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + not_done * self.discount * target_Q

# Get current Q estimates
current_Q1, current_Q2 = self.critic(state, action)

# Compute critic loss
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

# Optimize the critic
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()

# Delayed policy updates
if self.total_it % self.policy_freq == 0:

    # Compute actor loss
    actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

    # Optimize the actor
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    # Update the frozen target models
    for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
        target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
        target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
```
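To narrow down where the NaNs first appear, one option is to guard the intermediate tensors (`target_Q`, `current_Q1`, `critic_loss`, etc.) with a finiteness check so training stops at the first non-finite value instead of silently propagating it; `torch.autograd.set_detect_anomaly(True)` can do the same for the backward pass. A minimal sketch (the helper name `check_finite` is my own, not part of the code above):

```python
import torch

def check_finite(tag, *tensors):
    # Raise early if any tensor contains NaN or Inf, naming the offending quantity.
    for i, t in enumerate(tensors):
        if not torch.isfinite(t).all():
            raise RuntimeError(f"non-finite values in {tag}[{i}]")

# Example usage: guard the target computation (names mirror the snippet above).
target_Q1 = torch.tensor([[1.0], [2.0]])
target_Q2 = torch.tensor([[0.5], [float("nan")]])

check_finite("target_Q", target_Q1)      # passes silently
try:
    check_finite("target_Q", target_Q2)  # raises: NaN present
except RuntimeError as e:
    print(e)
```

Dropping a call like this after each stage of the update usually pinpoints whether the NaN originates in the critic targets, the critic loss, or the actor loss.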
The lr is 3e-4
The reward function is:
```python
if self.avg_hit >= 0.80 and self.avg_use >= 0.80:
    reward = 9
elif self.avg_hit < 0.70 or self.avg_use < 0.70:
    reward = -9
else:
    reward = 0
```
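One thing worth checking in this branching: the final `reward = 0` must sit in an explicit `else` branch. If it is a trailing unconditional assignment, it overwrites both the +9 and the -9 cases and the agent only ever observes 0. Wrapped as a standalone function (the name `compute_reward` is mine) for a quick sanity check:

```python
def compute_reward(avg_hit, avg_use):
    # Sparse three-level reward: +9 when both rates are high,
    # -9 when either is low, 0 in the intermediate band.
    if avg_hit >= 0.80 and avg_use >= 0.80:
        return 9
    elif avg_hit < 0.70 or avg_use < 0.70:
        return -9
    else:
        return 0

print(compute_reward(0.85, 0.90))  # 9
print(compute_reward(0.60, 0.90))  # -9
print(compute_reward(0.75, 0.75))  # 0
```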
The batch size is 256
I would like to add that it trains after every episode instead of every step, and it appears to be getting stuck on the actions that keep it in the zero-reward range rather than moving toward 9.
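For comparison, the reference TD3 code performs one gradient update per environment step (after a random warm-up phase), not one per episode. A minimal skeleton of that schedule, with `ToyEnv` and `CountingPolicy` as hypothetical stand-ins for your environment and agent:

```python
import random

class ToyEnv:
    # Hypothetical stand-in environment: state is a step counter,
    # episodes end after 10 steps.
    def reset(self):
        self.t = 0
        return 0.0
    def sample_action(self):
        return random.uniform(-1, 1)
    def step(self, action):
        self.t += 1
        return float(self.t), 0.0, self.t >= 10

class CountingPolicy:
    # Stand-in policy that just counts how many gradient updates it receives.
    def __init__(self):
        self.updates = 0
    def select_action(self, state):
        return 0.0
    def train(self, replay_buffer, batch_size):
        self.updates += 1

def run(env, policy, buffer, max_steps=50, start_steps=20, batch_size=256):
    state = env.reset()
    for t in range(max_steps):
        # Random actions during warm-up, then (noisy) policy actions.
        action = env.sample_action() if t < start_steps else policy.select_action(state)
        next_state, reward, done = env.step(action)
        buffer.append((state, action, next_state, reward, float(done)))
        state = env.reset() if done else next_state
        # Key difference from per-episode training:
        # one gradient update per environment step once warm-up is over.
        if t >= start_steps:
            policy.train(buffer, batch_size)
    return policy.updates

print(run(ToyEnv(), CountingPolicy(), []))  # 30 updates for 50 steps, 20 warm-up
```

Updating only once per episode gives the critics far fewer gradient steps per unit of experience, which can slow learning enough that the policy lingers in the zero-reward band.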