I am training models, and I am noticing that they are not converging to the optimal policy. I have made several modifications, but now the Q-values are becoming NaN; any insight as to why?
Can you share some more information?
- Loss function;
- Optimizer and lr;
- Reward function;
- Batch size.
I am using the default TD3 algorithm & parameters provided by the author:
```python
# Select action according to policy and add clipped noise
noise = (
    torch.randn_like(action) * self.policy_noise
).clamp(-self.noise_clip, self.noise_clip)

next_action = (
    self.actor_target(next_state) + noise
).clamp(-self.max_action, self.max_action)

# Compute the target Q value
target_Q1, target_Q2 = self.critic_target(next_state, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + not_done * self.discount * target_Q

# Get current Q estimates
current_Q1, current_Q2 = self.critic(state, action)

# Compute critic loss
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

# Optimize the critic
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()

# Delayed policy updates
if self.total_it % self.policy_freq == 0:

    # Compute actor loss
    actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

    # Optimize the actor
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    # Update the frozen target models
    for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
        target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
        target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
```
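To narrow down where the NaNs first appear, one option is to guard the intermediate tensors (`target_Q`, `current_Q1`, `critic_loss`, etc.) with a finiteness check so training stops at the first non-finite value instead of silently propagating it; `torch.autograd.set_detect_anomaly(True)` can do the same for the backward pass. A minimal sketch (the helper name `check_finite` is my own, not part of the code above):

```python
import torch

def check_finite(tag, *tensors):
    # Raise early if any tensor contains NaN or Inf, naming the offending quantity.
    for i, t in enumerate(tensors):
        if not torch.isfinite(t).all():
            raise RuntimeError(f"non-finite values in {tag}[{i}]")

# Example usage: guard the target computation (names mirror the snippet above).
target_Q1 = torch.tensor([[1.0], [2.0]])
target_Q2 = torch.tensor([[0.5], [float("nan")]])

check_finite("target_Q", target_Q1)      # passes silently
try:
    check_finite("target_Q", target_Q2)  # raises: NaN present
except RuntimeError as e:
    print(e)
```

Dropping a call like this after each stage of the update usually pinpoints whether the NaN originates in the critic targets, the critic loss, or the actor loss.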
The lr is 3e-4
The reward function is:
```python
if self.avg_hit >= 0.80 and self.avg_use >= 0.80:
    reward = 9
elif self.avg_hit < 0.70 or self.avg_use < 0.70:
    reward = -9
else:
    reward = 0
```
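One thing worth checking in this branching: the final `reward = 0` must sit in an explicit `else` branch. If it is a trailing unconditional assignment, it overwrites both the +9 and the -9 cases and the agent only ever observes 0. Wrapped as a standalone function (the name `compute_reward` is mine) for a quick sanity check:

```python
def compute_reward(avg_hit, avg_use):
    # Sparse three-level reward: +9 when both rates are high,
    # -9 when either is low, 0 in the intermediate band.
    if avg_hit >= 0.80 and avg_use >= 0.80:
        return 9
    elif avg_hit < 0.70 or avg_use < 0.70:
        return -9
    else:
        return 0

print(compute_reward(0.85, 0.90))  # 9
print(compute_reward(0.60, 0.90))  # -9
print(compute_reward(0.75, 0.75))  # 0
```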
The batch size is 256
I would like to add that it trains after every episode instead of every step, and it appears to be getting stuck on the actions that keep it in the zero-reward range rather than moving toward 9.
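For comparison, the reference TD3 code performs one gradient update per environment step (after a random warm-up phase), not one per episode. A minimal skeleton of that schedule, with `ToyEnv` and `CountingPolicy` as hypothetical stand-ins for your environment and agent:

```python
import random

class ToyEnv:
    # Hypothetical stand-in environment: state is a step counter,
    # episodes end after 10 steps.
    def reset(self):
        self.t = 0
        return 0.0
    def sample_action(self):
        return random.uniform(-1, 1)
    def step(self, action):
        self.t += 1
        return float(self.t), 0.0, self.t >= 10

class CountingPolicy:
    # Stand-in policy that just counts how many gradient updates it receives.
    def __init__(self):
        self.updates = 0
    def select_action(self, state):
        return 0.0
    def train(self, replay_buffer, batch_size):
        self.updates += 1

def run(env, policy, buffer, max_steps=50, start_steps=20, batch_size=256):
    state = env.reset()
    for t in range(max_steps):
        # Random actions during warm-up, then (noisy) policy actions.
        action = env.sample_action() if t < start_steps else policy.select_action(state)
        next_state, reward, done = env.step(action)
        buffer.append((state, action, next_state, reward, float(done)))
        state = env.reset() if done else next_state
        # Key difference from per-episode training:
        # one gradient update per environment step once warm-up is over.
        if t >= start_steps:
            policy.train(buffer, batch_size)
    return policy.updates

print(run(ToyEnv(), CountingPolicy(), []))  # 30 updates for 50 steps, 20 warm-up
```

Updating only once per episode gives the critics far fewer gradient steps per unit of experience, which can slow learning enough that the policy lingers in the zero-reward band.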