Retain_graph and Meta-Gradient issue in A2C with intrinsic reward

I’m trying to compute \nabla_{\eta} J_{ex} = \nabla_{\eta} \big( G_{ex} \cdot \pi_{new} / \pi_{old} \big), which is the intrinsic_loss in the code below.

But in order to get pi_new I have to modify the parameters of the policy, so I can’t compute this gradient without running into issues: retain_graph alone doesn’t help, because I first need to update the policy parameters to obtain pi_new, and after that update I can’t compute the gradient directly either, since it needs the graph up to that point.

I have seen code where someone applies the chain rule and computes the gradients by hand, then backpropagates that new gradient. But this needs a for loop, looks quite ugly, and feels more like a quick fix than a standard solution. Is there a more elegant way of computing the intrinsic loss?

  • I suppose higher could help, but I have never used it, so I don’t really know how it would apply in this case.
  • Another option could be to take the gradients of actor_loss, build a new policy from them without calling optimizer.step(), use retain_graph=True, and only call optimizer.step() at the very end. But again, this would mean passing the gradients around by hand (a rough sketch of what I mean is below this list).
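
To be concrete about the second option, this is roughly what I mean. It is only a sketch: virtual_actor_step and the lr value are made up, and torch.func.functional_call (which evaluates the module’s forward with the given parameters) is just one way to evaluate the “virtually” updated policy.

    import torch
    from torch.func import functional_call

    def virtual_actor_step(actor, actor_loss, lr):
        # One SGD-style step on the actor, returned as a dict of updated
        # parameters that still carry the graph of actor_loss. Nothing is
        # modified in place and optimizer.step() is not called here.
        names, params = zip(*actor.named_parameters())
        grads = torch.autograd.grad(actor_loss, params, create_graph=True)
        return {n: p - lr * g for n, p, g in zip(names, params, grads)}

    # Hypothetical usage against the code below:
    # new_params = virtual_actor_step(self.actor, actor_loss, lr=3e-4)
    # logits_new = functional_call(self.actor, new_params, (states,))  # pi_new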

Thanks

    def update(self, next_obs):
        # Get training batch:
        states, actions, rewards_ex, dones = self.memory.sample()
        
        # Get next value:
        next_value = self.get_value(next_obs)
        next_value_ex = self.lifetime_return(torch.from_numpy(next_obs).float()).detach()
        
        rewards_ex = torch.tensor(rewards_ex, dtype=torch.float)
        actions = torch.tensor(actions, dtype=torch.int64)
        states = torch.tensor(states, dtype=torch.float)
        dones = torch.tensor(dones, dtype=torch.float)

        # Compute returns:
        rewards = rewards_ex + self.lmbd * self.intrinsic_reward(states).gather(dim=1, index=actions.unsqueeze(-1)).flatten()
            
        returns_ex, returns = self.calculate_returns(next_value, next_value_ex, rewards, rewards_ex, dones)
        
        ##### Critic loss #####
        # Advantage:
        values = self.critic(states)
        adv = returns - values
        critic_loss = adv.pow(2).mean()*self.v_coef
        
        # Update Critic:
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
                
        # Actor loss
        entropies, log_probs = self.actor.get_log_prob_entropy(states, actions)
        policy_loss = - (log_probs * adv.detach()).mean()
        entropy_loss = entropies.mean()
        actor_loss = policy_loss  - self.entropy_coef * entropy_loss
        
        # update actor:
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        # Update Extrinsic Critic:
        # Advantage:
        values_ex = self.lifetime_return(states)
        adv_ex = returns_ex - values_ex
        ex_critic_loss = adv_ex.pow(2).mean()*self.v_coef
        
        self.lifetime_optimizer.zero_grad()
        ex_critic_loss.backward()
        self.lifetime_optimizer.step()
        
        # Intrinsic reward update:
        entropies_new, log_probs_new = self.actor.get_log_prob_entropy(states, actions)
        ratio = torch.exp(log_probs_new - log_probs)
        intrinsic_loss = -(adv_ex.detach() * ratio).mean()
        
        self.reward_optimizer.zero_grad()
        intrinsic_loss.backward()    # This is where the error lies
        self.reward_optimizer.step()
        
        self.old_policy.load_state_dict(self.actor.state_dict())
        
        # reset memory:
        self.memory.reset()
        
        return critic_loss.detach().numpy(), policy_loss.detach().numpy()

Hi @mateuspontesm
If I understand your problem correctly, you are trying to backpropagate through the update of your optimizer, am I right?
If that is the case, I would suggest using something like TorchOpt.
Here’s a code snippet of what it can do; hope it’s self-explanatory and spot on:

>>> import torchopt
>>> import torch
>>> from torchopt import MetaSGD
>>> net = torch.nn.Linear(3, 1)
>>> x = torch.randn(3)
>>> print(net.weight) # should be a parameter
Parameter containing:
tensor([[-0.0786, -0.1337, -0.4710]], requires_grad=True)
>>> optim_inner = MetaSGD(net, 0.001)
>>> sd = torchopt.extract_state_dict(net)
>>> loss_inner = net(x)
>>> optim_inner.step(loss_inner)
>>> print(net.weight) # not a parameter anymore
tensor([[-0.0467, -0.1747, -0.0833]], grad_fn=<AddBackward0>)
>>> loss_outer = net(x)
>>> loss_meta = loss_inner - loss_outer
>>> loss_meta  # different than 0
tensor([0.0062], grad_fn=<SubBackward0>)
>>> loss_meta.backward()
>>> torchopt.recover_state_dict(net, sd)  # puts the original parameters back in place

Thanks for the answer @vmoens. And yes, you are correct.
I have an intrinsic reward network that is used to update the value-function critic of the A2C, which estimates the mixed return (intrinsic + extrinsic). That critic is then used to update the actor with the A2C policy gradient. And this is where the problem lies: we need to update the intrinsic reward network with respect to the effect it had on the actor update, but instead of using the mixed (intrinsic + extrinsic) value function for that update, we use only the extrinsic advantage (Adv_ex). The extrinsic value function itself is trained normally, with an MSE loss against the extrinsic returns.

So when I update the intrinsic reward, I need the graph going back two steps, one for the actor update and one for the critic update, so that it correctly sees the effect the reward had. The “ratio” there is just importance sampling, to avoid needing samples from the new policy.
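
To spell out the chain I have in mind (my own shorthand, not a formal derivation: \eta are the reward-network parameters, \phi the mixed critic, \theta the policy, \alpha and \beta the inner learning rates; the \eta-dependence of the actor loss enters through the mixed advantage):

    \phi' = \phi - \beta \, \nabla_{\phi} \big( R_{mix}(\eta) - V_{\phi}(s) \big)^2
    \theta' = \theta - \alpha \, \nabla_{\theta} L_{actor}(\theta; \phi', \eta)
    \nabla_{\eta} J_{ex} = \nabla_{\theta'} J_{ex}(\theta') \cdot \nabla_{\eta} \theta',
    \qquad J_{ex}(\theta') \approx A_{ex} \cdot \pi_{\theta'}(a \mid s) \, / \, \pi_{\theta}(a \mid s)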

It seems that torchopt can help, so I’m currently going through the tutorials to implement this. My guess from your code is that I should use MetaAdam for the A2C value critic and for the policy, while the intrinsic reward and the extrinsic value function can keep a normal Adam, since their parameters are not used in the “outer loop”. Also, I should call stop_gradient on those MetaAdam-managed modules at the end of the update function, so that the next step starts from a clean graph.
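
To make sure I understand the mechanics, here is a minimal, self-contained sketch of that plan. The toy networks, fake batch, learning rates and the 0.01 mixing coefficient are all placeholders, and I’m assuming torchopt.MetaAdam and torchopt.stop_gradient behave as in the docs:

    import torch
    import torchopt

    # Toy stand-ins for the real networks, just to make the sketch runnable:
    actor = torch.nn.Linear(4, 2)         # policy logits
    critic = torch.nn.Linear(4, 1)        # mixed-return critic
    reward_net = torch.nn.Linear(4, 2)    # intrinsic reward per action (eta)
    lifetime_net = torch.nn.Linear(4, 1)  # extrinsic value function (its own MSE update omitted)

    # MetaAdam for the nets we differentiate through, plain Adam for the reward net:
    actor_optim = torchopt.MetaAdam(actor, lr=3e-4)
    critic_optim = torchopt.MetaAdam(critic, lr=1e-3)
    reward_optim = torch.optim.Adam(reward_net.parameters(), lr=1e-3)

    # Fake batch:
    states = torch.randn(8, 4)
    actions = torch.randint(0, 2, (8, 1))
    returns_ex = torch.randn(8)

    # The mixed return depends on the reward net, so it is NOT detached here:
    r_in = reward_net(states).gather(1, actions).flatten()
    returns_mix = returns_ex + 0.01 * r_in

    # Inner step 1: critic on the mixed return (MetaAdam keeps the graph alive).
    critic_loss = (returns_mix - critic(states).flatten()).pow(2).mean()
    critic_optim.step(critic_loss)

    # Inner step 2: actor step; the advantage stays non-detached so that the
    # updated policy actually depends on the reward net.
    log_probs_old = torch.log_softmax(actor(states), dim=-1).gather(1, actions).flatten()
    adv = returns_mix - critic(states).flatten()
    actor_optim.step(-(log_probs_old * adv).mean())

    # Outer step: extrinsic advantage * importance ratio, backprop into the reward net.
    adv_ex = (returns_ex - lifetime_net(states).flatten()).detach()
    log_probs_new = torch.log_softmax(actor(states), dim=-1).gather(1, actions).flatten()
    ratio = torch.exp(log_probs_new - log_probs_old.detach())
    intrinsic_loss = -(adv_ex * ratio).mean()
    reward_optim.zero_grad()
    intrinsic_loss.backward()
    reward_optim.step()

    # Cut the graph so the next update starts from plain leaf parameters again:
    torchopt.stop_gradient(actor)
    torchopt.stop_gradient(critic)
    torchopt.stop_gradient(actor_optim)
    torchopt.stop_gradient(critic_optim)

If this is the right shape, mapping it back onto the real update() above should mostly be a matter of swapping in the actual networks and losses.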