In this code block I define the two networks (actor and critic):
class Actor(nn.Module):
    """Policy network: an MLP mapping observations to categorical action logits."""

    def __init__(self, obs_size: int, action_size: int, hidden_size: int,
                 activation=None):
        """Build the policy MLP (two hidden layers of width ``hidden_size``).

        Args:
            obs_size: number of input features per observation.
            action_size: number of discrete actions (one output logit each).
            hidden_size: width of both hidden layers.
            activation: module placed after each hidden layer. Defaults to a
                fresh ``nn.Tanh()`` created per instance. The same module
                object is used in both positions, so a parametric activation
                would share its parameters between the two layers.
        """
        super().__init__()
        if activation is None:
            # Build the default here instead of in the signature: a default
            # expression is evaluated once at definition time and would be
            # the same module object for every Actor ever created.
            activation = nn.Tanh()
        self.action = nn.Sequential(
            layer_init(nn.Linear(obs_size, hidden_size)),
            activation,
            layer_init(nn.Linear(hidden_size, hidden_size)),
            activation,
            # NOTE(review): std=0.01 presumably keeps the initial logits small
            # (near-uniform initial policy); depends on layer_init, defined
            # elsewhere — confirm.
            layer_init(nn.Linear(hidden_size, action_size), std=0.01),
        )

    def forward(self):
        """Not supported; call :meth:`get_action` instead."""
        raise NotImplementedError

    def get_action(self, state, action=None):
        """Sample an action for ``state``, or score a given ``action``.

        Args:
            state: observation(s) fed to the policy MLP.
            action: if ``None`` an action is sampled from the policy;
                otherwise this action is evaluated.

        Returns:
            ``(action, log_prob(action), entropy)`` of the categorical policy.
        """
        logits = self.action(state)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy()
class Critic(nn.Module):
    """Value network: an MLP mapping observations to a single scalar value."""

    def __init__(self, obs_size: int, hidden_size: int, activation=None):
        """Build the value MLP (two hidden layers of width ``hidden_size``).

        Args:
            obs_size: number of input features per observation.
            hidden_size: width of both hidden layers.
            activation: module placed after each hidden layer. Defaults to a
                fresh ``nn.Tanh()`` created per instance. The same module
                object is used in both positions.
        """
        super().__init__()
        if activation is None:
            # Avoid a definition-time default module shared by all Critics.
            activation = nn.Tanh()
        self.value = nn.Sequential(
            layer_init(nn.Linear(obs_size, hidden_size)),
            activation,
            layer_init(nn.Linear(hidden_size, hidden_size)),
            activation,
            # Single output unit: the state-value estimate.
            layer_init(nn.Linear(hidden_size, 1), std=1.)
        )

    def forward(self):
        """Not supported; call :meth:`get_value` instead."""
        raise NotImplementedError

    def get_value(self, state):
        """Return the critic's value estimate(s) for ``state``.

        The output has a trailing dimension of size 1, from the last Linear layer.
        """
        return self.value(state)
The networks are instantiated like this:
# Separate policy and value networks, both moved to the training device.
self.actor = Actor(
    obs_size=self.obs_size,
    action_size=self.action_size,
    hidden_size=self.hidden_size,
).to(self.device)
self.critic = Critic(
    obs_size=self.obs_size,
    hidden_size=self.hidden_size,
).to(self.device)
The optimizers:
# One Adam optimizer per network, each with its own learning rate.
self.actor_optimizer = optim.Adam(params=self.actor.parameters(), lr=self.actor_lr)
self.critic_optimizer = optim.Adam(params=self.critic.parameters(), lr=self.critic_lr)
The function that updates the networks:
def update_model(self, next_state: "np.ndarray | torch.Tensor"):
    """Run the PPO optimisation phase on the rollout stored on ``self``.

    Bootstraps the value of ``next_state`` with the critic, computes returns
    and advantages with ``compute_gae``, then performs ``self.epochs`` passes
    over the trajectory in shuffled mini-batches of ``self.mini_batch_size``.
    Each mini-batch updates the actor (clipped surrogate objective minus an
    entropy bonus) and the critic (optionally clipped value loss) with their
    own optimizers.

    Args:
        next_state: observation following the last stored transition. Either
            a ``torch.Tensor`` or array-like data (e.g. a NumPy array), which
            is converted to a float32 tensor.

    Returns:
        ``(actor_loss, critic_loss)`` from the last mini-batch processed, or
        ``(None, None)`` if no optimisation step ran.
    """
    if not torch.is_tensor(next_state):
        next_state = torch.as_tensor(next_state, dtype=torch.float32)

    # Bootstrap value and GAE targets are constants for the optimisation
    # phase. Without no_grad, `returns` could carry the critic's graph: the
    # critic loss would backprop into the bootstrap, and the second
    # mini-batch would fail with "backward through the graph a second time".
    with torch.no_grad():
        last_value = self.critic.get_value(next_state.to(self.device)).reshape(1, -1)
        returns, advantages = compute_gae(last_value, self.rewards, self.masks, self.values,
                                          self.gamma, self.lam, self.device)

    # Flatten the rollout buffers. NOTE(review): squeeze() also drops the
    # feature dim if obs_size == 1; confirm buffer shapes.
    states_traj = self.states.squeeze()
    # Old log-probs and values come from the rollout policy/critic and must
    # not receive gradients.
    log_probs_traj = self.log_probs.squeeze().detach()
    values_traj = self.values.squeeze().detach()
    actions_traj = self.actions.squeeze().long()  # Categorical.log_prob needs integer actions
    advantages_traj = advantages.squeeze()
    returns_traj = returns.squeeze()

    actor_loss = critic_loss = None
    ids = np.arange(self.trajectory_size)
    for _ in range(self.epochs):
        np.random.shuffle(ids)
        for start in range(0, self.trajectory_size, self.mini_batch_size):
            minibatch_ind = ids[start:start + self.mini_batch_size]

            advantages_minib = advantages_traj[minibatch_ind]
            # std() of a single element is NaN, which would poison the update;
            # only normalise when the mini-batch has at least two samples.
            if self.normalize_adv and len(minibatch_ind) > 1:
                advantages_minib = (advantages_minib - advantages_minib.mean()) / (advantages_minib.std() + 1e-8)

            # --- actor loss: PPO clipped surrogate + entropy bonus ---
            _, newlogproba, entropy = self.actor.get_action(states_traj[minibatch_ind],
                                                            actions_traj[minibatch_ind])
            ratio = (newlogproba - log_probs_traj[minibatch_ind]).exp()
            surr_loss = -advantages_minib * ratio
            clipped_surr_loss = -advantages_minib * torch.clamp(ratio, 1 - self.epsilon,
                                                                1 + self.epsilon)
            actor_loss_max = torch.max(surr_loss, clipped_surr_loss).mean()
            actor_loss = actor_loss_max - self.entropy_weight * entropy.mean()

            # --- critic loss: (optionally clipped) squared error to returns ---
            new_values = self.critic.get_value(states_traj[minibatch_ind]).view(-1)
            returns_minib = returns_traj[minibatch_ind]
            critic_loss_unclipped = (new_values - returns_minib) ** 2
            if self.clipped_value_loss:
                old_values = values_traj[minibatch_ind]
                value_clipped = old_values + torch.clamp(new_values - old_values,
                                                         -self.epsilon, self.epsilon)
                critic_loss_clipped = (value_clipped - returns_minib) ** 2
                critic_sq_err = torch.max(critic_loss_clipped, critic_loss_unclipped)
            else:
                # Square the full difference (the previous version only squared
                # the returns because of operator precedence).
                critic_sq_err = critic_loss_unclipped
            critic_loss = 0.5 * critic_sq_err.mean() * self.critic_weight

            # Separate networks: each loss updates only its own parameters.
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
            self.critic_optimizer.step()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
            self.actor_optimizer.step()
    return actor_loss, critic_loss