I’m trying to compute \nabla_{\eta} J_{ex} = \nabla_{\eta} G_{ex} \cdot \pi_{new} / \pi_{old}, which is the `intrinsic_loss` in the code below.
But in order to obtain pi_new I have to modify the parameters of the policy, so I can’t compute this gradient without running into issues. retain_graph doesn’t help, because I need to update the policy parameters to get pi_new, and I can’t compute the gradient directly since it needs the graph up to that point.
I have seen code where someone applies the chain rule, computes the gradients by hand, and then backpropagates this new gradient. But that needs a for loop, looks very ugly, and feels like a quick fix rather than a standard solution. Is there a more elegant way of computing the intrinsic loss?
- I suppose higher could help, but I have never used it, so I don’t really know how it would apply in this case.
- Another option would be to access the gradients of actor_loss, build a new_policy from them without calling optimizer.step(), use retain_graph=True, and only call optimizer.step() at the end. But again, this would have to be done by hand.
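For what it’s worth, this pattern is usually called a meta-gradient: the outer gradient w.r.t. the intrinsic-reward parameters has to flow *through* the inner policy update, so the update step itself must stay differentiable. `higher` automates this, but the same idea works in plain PyTorch by taking the inner gradients with `create_graph=True` and evaluating the updated policy functionally. A minimal, self-contained sketch (the tiny linear policy, `eta`, and the losses here are placeholders I made up, not your actual modules):

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
policy = nn.Linear(4, 2)                   # stand-in for the actor
eta = torch.randn(4, requires_grad=True)   # intrinsic-reward parameters

states = torch.randn(8, 4)
actions = torch.randint(0, 2, (8,))

# Inner loss depends on eta through the intrinsic reward term.
logits = policy(states)
log_probs = torch.log_softmax(logits, dim=-1).gather(1, actions.unsqueeze(-1)).squeeze(-1)
intrinsic = states @ eta                   # placeholder intrinsic reward
inner_loss = -(log_probs * intrinsic).mean()

# Differentiable "SGD step": create_graph=True keeps the graph through the
# gradients, so new_params is a differentiable function of eta.
lr = 0.1
params = dict(policy.named_parameters())
grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)
new_params = {name: p - lr * g for (name, p), g in zip(params.items(), grads)}

# Outer loss evaluates the *updated* policy without touching policy's
# stored parameters; pi_old is detached, as in an importance ratio.
new_logits = functional_call(policy, new_params, (states,))
new_log_probs = torch.log_softmax(new_logits, dim=-1).gather(1, actions.unsqueeze(-1)).squeeze(-1)
ratio = torch.exp(new_log_probs - log_probs.detach())
outer_loss = -ratio.mean()

# Gradient of the outer loss w.r.t. eta flows through the inner update.
grad_eta, = torch.autograd.grad(outer_loss, eta)
print(grad_eta.shape)  # torch.Size([4])
```

With `higher`, the `autograd.grad`/`functional_call` pair is replaced by `higher.innerloop_ctx(actor, actor_optimizer)`, which (as I understand its API) yields a differentiable functional copy of the actor plus an optimizer whose `step()` keeps the graph, so the outer backward can reach eta the same way.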
Thanks
def update(self, next_obs):
    # Get training batch:
    states, actions, rewards_ex, dones = self.memory.sample()
    # Get next value:
    next_value = self.get_value(next_obs)
    next_value_ex = self.lifetime_return(torch.from_numpy(next_obs).float()).detach()
    rewards_ex = torch.tensor(rewards_ex, dtype=torch.float)
    actions = torch.tensor(actions, dtype=torch.int64)
    states = torch.tensor(states, dtype=torch.float)
    dones = torch.tensor(dones, dtype=torch.float)
    # Compute returns:
    rewards = rewards_ex + self.lmbd * self.intrinsic_reward(states).gather(dim=1, index=actions.unsqueeze(-1)).flatten()
    returns_ex, returns = self.calculate_returns(next_value, next_value_ex, rewards, rewards_ex, dones)
    ##### Critic loss #####
    # Advantage:
    values = self.critic(states)
    adv = returns - values
    critic_loss = adv.pow(2).mean() * self.v_coef
    # Update critic:
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()
    # Actor loss:
    entropies, log_probs = self.actor.get_log_prob_entropy(states, actions)
    policy_loss = -(log_probs * adv.detach()).mean()
    entropy_loss = entropies.mean()
    actor_loss = policy_loss - self.entropy_coef * entropy_loss
    # Update actor:
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()
    # Update extrinsic critic:
    # Advantage:
    values_ex = self.lifetime_return(states)
    adv_ex = returns_ex - values_ex
    ex_critic_loss = adv_ex.pow(2).mean() * self.v_coef
    self.lifetime_optimizer.zero_grad()
    ex_critic_loss.backward()
    self.lifetime_optimizer.step()
    # Intrinsic reward update:
    entropies_new, log_probs_new = self.actor.get_log_prob_entropy(states, actions)
    ratio = torch.exp(log_probs_new - log_probs)
    intrinsic_loss = -(adv_ex.detach() * ratio).mean()
    self.reward_optimizer.zero_grad()
    intrinsic_loss.backward()  # This is where the error lies
    self.reward_optimizer.step()
    self.old_policy.load_state_dict(self.actor.state_dict())
    # Reset memory:
    self.memory.reset()
    return critic_loss.detach().numpy(), policy_loss.detach().numpy()