I know trying to debug someone else’s code can be very difficult, so I am going to try to make it as easy as possible.
I have been trying to debug it for over a week and have not been able to solve it.
My truncated code:
import copy

import gymnasium
import torch
import yaml

TORCH_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class centralized_ddpg_agent_actor(torch.nn.Module):
    def __init__(self, action_space_size, observation_state_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(observation_state_size, 128, device=TORCH_DEVICE)
        self.linear2 = torch.nn.Linear(128, 256, device=TORCH_DEVICE)
        self.linear3 = torch.nn.Linear(256, action_space_size, device=TORCH_DEVICE)

    def forward(self, observations):
        output = torch.tanh(self.linear1(observations))
        output = torch.tanh(self.linear2(output))
        output = torch.tanh(self.linear3(output))
        return output

class centralized_ddpg_agent_critic(torch.nn.Module):
    def __init__(self, action_space_size, observation_state_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(action_space_size + observation_state_size, 128, device=TORCH_DEVICE)
        self.linear2 = torch.nn.Linear(128, 256, device=TORCH_DEVICE)
        self.linear3 = torch.nn.Linear(256, 1, device=TORCH_DEVICE)

    def forward(self, observations, actions):
        output = torch.tanh(self.linear1(torch.cat((observations, actions), dim=1)))
        output = torch.tanh(self.linear2(output))
        value = torch.tanh(self.linear3(output))
        return value

# source: https://github.com/ghliu/pytorch-ddpg/blob/master/util.py
def soft_update_target_network(target, source, tau):
    assert tau >= 0 and tau <= 1
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

class DDPG_model():
    def __init__(self, num_actions, num_states, yaml_config):
        self.learning_rate = yaml_config['DDPG']['gamma']  # read from the 'gamma' key; used as the discount factor in train_model_step()
        self.target_rate = yaml_config['DDPG']['tau']
        self.mini_batch_size = yaml_config['DDPG']['N']
        self.noise_variance = yaml_config['DDPG']['noise_var']

        self.actor = centralized_ddpg_agent_actor(num_actions, num_states)  # mu
        self.target_actor = copy.deepcopy(self.actor)  # mu'
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters())  # TODO check learning rate
        self.critic = centralized_ddpg_agent_critic(num_actions, num_states)  # q
        self.target_critic = copy.deepcopy(self.critic)  # q'
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters())
        self.critic_criterion = torch.nn.MSELoss()

        experience_replay_buffer_size = yaml_config['DDPG']['experience_replay_buffer_size']
        self.erb = experience_replay_buffer(experience_replay_buffer_size)

    def query_actor(self, state):
        # actor output plus Gaussian exploration noise, clamped to the environment's action bounds
        return torch.clamp(self.actor(state) + torch.randn(num_actions).to(TORCH_DEVICE) * (self.noise_variance ** 0.5),
                           min=env.action_space.low[0], max=env.action_space.high[0])

    def train_model_step(self):
        if len(self.erb.buffer) < self.mini_batch_size:
            return

        # calculate optimizer inputs from the sampled mini-batch
        old_state_batch, actions_batch, reward_batch, new_state_batch = self.erb.sample_batch_and_split(self.mini_batch_size)
        q = self.critic(old_state_batch, actions_batch)
        # TD target: r + gamma * Q'(s', mu'(s'))
        y = reward_batch + self.learning_rate * self.target_critic(new_state_batch, self.target_actor(new_state_batch))

        # update critic
        self.critic_optimizer.zero_grad()
        critic_loss = self.critic_criterion(q, y)
        critic_loss.backward()
        self.critic_optimizer.step()

        # update actor
        self.actor_optimizer.zero_grad()
        policy_loss = (-self.critic(old_state_batch, self.actor(old_state_batch))).mean()
        print(policy_loss)  # quickly converges to ±1
        policy_loss.backward()
        self.actor_optimizer.step()

        # update target networks
        soft_update_target_network(self.target_actor, self.actor, self.target_rate)
        soft_update_target_network(self.target_critic, self.critic, self.target_rate)

if __name__ == "__main__":
    config = yaml.safe_load(open('config.yaml', 'r'))
    env = gymnasium.make('Pendulum-v1')
    env_eval = gymnasium.make('Pendulum-v1', render_mode='human')
    ...  # create evaluate file

    num_actions = env.action_space.shape[0]
    num_states = env.observation_space.shape[0]
    DDPG = DDPG_model(num_actions, num_states, config)

    for episode in range(config['domain']['episodes']):
        cur_state = torch.tensor(env.reset()[0], dtype=torch.float32).to(TORCH_DEVICE)
        for step in range(env.spec.max_episode_steps):
            actions = DDPG.query_actor(cur_state)
            new_state, reward, is_terminal, is_truncated, info = env.step(actions.tolist())
            DDPG.erb.add_experience(old_state=cur_state, actions=actions.detach(), reward=reward,
                                    new_state=torch.tensor(new_state, dtype=torch.float32).to(TORCH_DEVICE),
                                    is_terminal=is_terminal or is_truncated)
            cur_state = torch.tensor(new_state, dtype=torch.float32).to(TORCH_DEVICE)
            DDPG.train_model_step()
            if is_terminal:
                break
        ...  # evaluate episode
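The experience_replay_buffer class is not shown in the truncated code above; a minimal sketch of the interface the rest of the code relies on (an assumed simplification, not the exact implementation) looks like this:

import collections
import random

Experience = collections.namedtuple(
    'Experience', ['old_state', 'actions', 'reward', 'new_state', 'is_terminal'])

class experience_replay_buffer():
    def __init__(self, max_size):
        self.buffer = collections.deque(maxlen=max_size)

    def add_experience(self, old_state, actions, reward, new_state, is_terminal):
        self.buffer.append(Experience(old_state, actions, reward, new_state, is_terminal))

    def sample_batch_and_split(self, batch_size):
        # uniformly sample a mini-batch and return stacked tensors
        batch = random.sample(self.buffer, batch_size)
        old_states = torch.stack([e.old_state for e in batch])
        actions = torch.stack([e.actions for e in batch])
        rewards = torch.tensor([[e.reward] for e in batch], dtype=torch.float32, device=TORCH_DEVICE)
        new_states = torch.stack([e.new_state for e in batch])
        return old_states, actions, rewards, new_states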
The agent fails to learn even on the simplest continuous-control Gymnasium domain, classic_control/Pendulum.
The policy_loss (in DDPG.train_model_step()) quickly converges (in roughly 200 steps) to either +1 or -1 regardless of state, which happens because the critic's output converges to -1 or +1 regardless of input.
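For example, printing the critic's predictions against the regression targets inside train_model_step (a quick diagnostic I can add after q and y are computed, not part of the code above) illustrates what I mean:

with torch.no_grad():
    print('q min/max:', q.min().item(), q.max().item())  # stays within [-1, 1] because of the final tanh in the critic
    print('y min/max:', y.min().item(), y.max().item())  # includes the raw Pendulum reward, so it can fall well outside [-1, 1]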
I have tried tweaking the hyperparameters, with no improvement.
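For reference, these are the config.yaml keys the code reads (the values shown here are placeholders only, not a recommended setup):

DDPG:
  gamma: 0.99
  tau: 0.005
  N: 64
  noise_var: 0.1
  experience_replay_buffer_size: 100000
domain:
  episodes: 200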
Feel free to ask for any clarification.
Thanks for just reading this, I appreciate it.
My full code can be found at: