Using backward() of a loss function for multiple networks

Hi, I have implemented an actor-critic method with linear networks for both the actor and the critic. I am using a custom loss function, the one used in PPO. The environment's reward function is different from the networks' loss function.
This is the loss calculation:

dist = self.actor(states)                      # action distribution from the actor
critic_value = self.critic(states)             # state-value estimate from the critic
critic_value = torch.squeeze(critic_value)

new_probs = dist.log_prob(actions)
prob_ratio = new_probs.exp() / old_probs.exp()           # pi_new(a|s) / pi_old(a|s)
weighted_probs = advantage[batch.item()] * prob_ratio
weighted_clipped_probs = torch.clamp(prob_ratio, 1 - self.policy_clip,
                                     1 + self.policy_clip) * advantage[batch.item()]
actor_loss = -torch.min(weighted_probs, weighted_clipped_probs).mean()   # clipped PPO objective

returns = advantage[batch.item()] + values[batch.item()]
critic_loss = ((returns - critic_value) ** 2).mean()     # MSE between returns and value estimates

total_loss = actor_loss + 0.5 * critic_loss

And then the model optimization:

self.actor.optimizer.zero_grad()
self.critic.optimizer.zero_grad()
total_loss.backward()          # backpropagates through both the actor and the critic
self.actor.optimizer.step()
self.critic.optimizer.step()

Now my questions:

  1. Is it correct to use backward() in this way?
  2. Does this backward() call compute gradients for both the actor and the critic networks, so that both networks are optimized?
  3. If I add a feature extractor network that processes the inputs first, and feed the extracted features to both the actor and the critic, will this backward() call also work for the feature extractor network?

@ptrblck Would you mind helping me, please?

  1. Yes, calling total_loss.backward() looks correct. I don’t know if the implemented training logic fits your use case, as I’m not familiar with it.

  2. Yes, as long as no tensors are (accidentally) detached. To verify it, check the .grad attributes of the parameters of both models before and after the first backward() call. Before the call, they should be set to None, afterwards to a valid gradient tensor (see the first sketch after this list).

  3. Yes, and the same applies as in 2.: if you are not detaching the tensors, the feature extractor will also get gradients (see the second sketch after this list).
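
A minimal sketch of the check from 2., assuming self.actor and self.critic are the modules from your snippet and total_loss was computed exactly as above:

# Sanity check: before the first backward() call every parameter's .grad
# should be None, afterwards it should be a valid gradient tensor.
params = list(self.actor.named_parameters()) + list(self.critic.named_parameters())

for name, p in params:
    print(name, p.grad)                        # expected: None before backward()

total_loss.backward()

for name, p in params:
    # expected: a finite value for every parameter that participated in the loss
    print(name, None if p.grad is None else p.grad.abs().sum().item())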
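
And a minimal sketch for 3., assuming a hypothetical self.feature_extractor module with its own hypothetical self.feature_optimizer (alternatively, its parameters could be added to one of the existing optimizers). As long as its output is not detached, the same backward() call also populates its gradients:

features = self.feature_extractor(states)      # shared trunk, output not detached
dist = self.actor(features)
critic_value = torch.squeeze(self.critic(features))

# ... compute total_loss exactly as in the loss calculation above ...

self.actor.optimizer.zero_grad()
self.critic.optimizer.zero_grad()
self.feature_optimizer.zero_grad()             # hypothetical optimizer over feature_extractor.parameters()

total_loss.backward()                          # also fills the feature extractor's .grad attributes

self.actor.optimizer.step()
self.critic.optimizer.step()
self.feature_optimizer.step()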
