Puzzling results: separate Actor and Critic networks vs. a combined Actor-Critic network

Hi,
I’ve been experimenting with neural networks and deep learning for quite some time. Recently I made an observation that really puzzles me:
I was trying to improve the REINFORCE method towards actor-critic. In doing so, I found that there are two different ways of implementing actor-critic:

  1. Actor-critic with two separate networks, one for the actor and one for the critic.
  2. Actor-critic with a single network, where the policy head and the value head are separate but share a common body.

In terms of network complexity, the two implementations are very similar; the combined network is even slightly smaller, because the 128-unit hidden layer is shared (see the quick parameter count after the code below).
In terms of performance, however, there is a substantial difference! The first solution, with an independent actor and critic, works far better than the combined network. Interestingly, almost all actor-critic implementations in the literature use the latter, seemingly less performant, variant.

So my question is simple: what is the reason for this difference? Or am I wrong?

For simplicity I’m using CartPole in the example below.

The first variant, with separate networks, looks like this:

class Actor(nn.Module):
    def __init__(self, input_size, n_actions):
        super(Actor, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )

    def forward(self, x):
        z = self.net(x)
        distribution = Categorical(F.softmax(z, dim=-1))
        return distribution

class Critic(nn.Module):
    def __init__(self, input_size, n_actions):
        super(Critic, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):        
        return self.net(x)

The second variant, with one shared body and two heads, looks like this:

class Actor_Critic(nn.Module):
    def __init__(self, input_size, n_actions):
        super(Actor_Critic, self).__init__()

        self.net_base = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU()   
        )

        self.net_head_actor =  nn.Sequential(
            nn.Linear(128, n_actions)
        )
        
        self.net_head_critic = nn.Sequential(
            nn.Linear(128, 1)
        )
        
    def forward(self, x):
        a = self.net_base(x)
        b = self.net_head_actor(a)
        c = self.net_head_critic(a)
        distribution = Categorical(F.softmax(b, dim=-1))
        return distribution, c
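
As a quick sanity check of the complexity claim, here is a throwaway parameter count for CartPole’s 4-dimensional state and 2 actions (it only uses the classes above plus the imports from the full code further below; it is not part of the training script):

def count_params(m):
    return sum(p.numel() for p in m.parameters())

print(count_params(Actor(4, 2)) + count_params(Critic(4, 2)))  # separate: 898 + 769 = 1667 parameters
print(count_params(Actor_Critic(4, 2)))                        # combined: 640 + 258 + 129 = 1027 parameters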

The full code to run these networks is shown below, but first some results.
In my runs, CartPole counts as solved when the mean reward over the last 50 episodes exceeds 195. The number of episodes needed to reach that threshold is a good proxy for training speed. Since training depends on the initial network weights and the environment’s random start states, each variant is trained 10 times and the output of every run is printed.
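
(Concretely, the solve check inside the training loop below is nothing more than a rolling mean over the episode scores:)

solved = np.mean(total_score[-50:]) > 195   # total_score holds the per-episode scores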

For the separate actor and critic networks the results are as follows:
Results for Actor_Critic_separate :
Solved in 171 Episodes, with 23410 total Iterations; Mean Score 195.18
Solved in 288 Episodes, with 39820 total Iterations; Mean Score 195.52
Solved in 146 Episodes, with 15855 total Iterations; Mean Score 195.3
Solved in 166 Episodes, with 20582 total Iterations; Mean Score 195.78
Solved in 272 Episodes, with 13834 total Iterations; Mean Score 195.3
Solved in 260 Episodes, with 24027 total Iterations; Mean Score 195.24
Solved in 240 Episodes, with 36247 total Iterations; Mean Score 195.02
Solved in 168 Episodes, with 22186 total Iterations; Mean Score 196.26
Solved in 147 Episodes, with 19528 total Iterations; Mean Score 195.34
Solved in 161 Episodes, with 16442 total Iterations; Mean Score 195.56
Statistics for Actor_Critic_separate
-> Mean number of Episodes: 201.9, Std.Deviation: 53.253075028584036, Min number of Episodes: 146

For the combined actor-critic network the results are these:
Results for Actor_Critic_combined :
Solved in 732 Episodes, with 74549 total Iterations; Mean Score 195.48
Solved in 936 Episodes, with 41875 total Iterations; Mean Score 195.16
NOT Solved in 1001 Episodes, with 106593 total Iterations; Mean Score 164.34
Solved in 423 Episodes, with 54609 total Iterations; Mean Score 196.02
Solved in 724 Episodes, with 94881 total Iterations; Mean Score 195.3
Solved in 780 Episodes, with 100322 total Iterations; Mean Score 196.54
Solved in 501 Episodes, with 61482 total Iterations; Mean Score 195.8
NOT Solved in 1001 Episodes, with 92636 total Iterations; Mean Score 188.28
Solved in 602 Episodes, with 49787 total Iterations; Mean Score 195.64
Solved in 679 Episodes, with 75490 total Iterations; Mean Score 195.5
Statistics for Actor_Critic_combined
-> Mean number of Episodes: 737.9, Std.Deviation: 188.92032712230835, Min number of Episodes: 423

There is a substantial difference in performance, and it is consistent across higher and lower learning rates (0.1…0.0001 tested). On average, the separate networks need only about 202 episodes to solve the task, while the combined network needs more than 700 and also shows much higher variance.
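
(The learning-rate comparison itself was nothing sophisticated; conceptually it is just a loop like this rough sketch around the trainIters function from the full code below:)

for lr in [0.1, 0.01, 0.001, 0.0001]:
    for approach in ["Actor_Critic_separate", "Actor_Critic_combined"]:
        actor = Actor(state_size, action_size).to(device)
        critic = Critic(state_size, action_size).to(device)
        actor_critic = Actor_Critic(state_size, action_size).to(device)
        trainIters(actor, critic, actor_critic, approach, learn_rate=lr, n_iters=1000)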

So I’m really wondering what needs to be changed to make the combined network as efficient as, or even better than, the separate networks. I’ve spent hours on it, but I have no ideas left…

So any help or idea is really appreciated!

The code that glues everything together is shown here; it’s a pretty simple actor-critic setup…

import gym, os
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = "cpu"
env = gym.make("CartPole-v0")

state_size = env.observation_space.shape[0]
action_size = env.action_space.n


class Actor(nn.Module):
    def __init__(self, input_size, n_actions):
        super(Actor, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )

    def forward(self, x):
        z = self.net(x)
        distribution = Categorical(F.softmax(z, dim=-1))
        return distribution

class Critic(nn.Module):
    def __init__(self, input_size, n_actions):
        super(Critic, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):        
        return self.net(x)

class Actor_Critic(nn.Module):
    def __init__(self, input_size, n_actions):
        super(Actor_Critic, self).__init__()

        self.net_base = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU()   
        )

        self.net_head_actor =  nn.Sequential(
            nn.Linear(128, n_actions)
        )
        
        self.net_head_critic = nn.Sequential(
            nn.Linear(128, 1)
        )
        
    def forward(self, x):
        a = self.net_base(x)
        b = self.net_head_actor(a)
        c = self.net_head_critic(a)
        distribution = Categorical(F.softmax(b, dim=-1))
        return distribution, c

    

def compute_returns(rewards, gamma=0.99):
    R = 0
    returns = []
    for step in reversed(range(len(rewards))):
        R = rewards[step] + gamma * R 
        returns.insert(0, R)
    return returns
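
# Worked example for compute_returns with gamma = 0.99:
#   rewards = [1, 1, 1]  ->  returns = [1 + 0.99*1.99, 1 + 0.99*1, 1] = [2.9701, 1.99, 1.0]
# i.e. the discounted Monte-Carlo return of every step, computed backwards from the end.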


def trainIters(actor, critic, actor_critic, AC_approach, learn_rate, n_iters):
    optimizerA = optim.Adam(actor.parameters(), lr=learn_rate)
    optimizerC = optim.Adam(critic.parameters(), lr=learn_rate)
    optimizerAC = optim.Adam(actor_critic.parameters(), lr=learn_rate)
 
    Episode = 0
    total_score = []
    total_count = 0
    while True:
        Episode += 1
        
        
        state = env.reset()  # initial observation for this episode
        log_probs = []
        values = []
        rewards = []

        for i in count():            
            total_count += 1
            state = torch.FloatTensor(state).to(device)
    
            if AC_approach == "Actor_Critic_separate":
                dist, value  = actor(state), critic(state)      
            if AC_approach == "Actor_Critic_combined":  
                dist, value  = actor_critic(state)

            action = dist.sample()  # corresponds to np.random.choice(len(dist), p=dist)
            
            next_state, reward, done, _ = env.step(action.cpu().numpy())

            log_prob = dist.log_prob(action).unsqueeze(0)           

            log_probs.append(log_prob)

            values.append(value)
            rewards.append(torch.tensor([reward], dtype=torch.float, device=device))
            state = next_state

            if done:
                total_score.append(i+1)
#                if Episode % 10 == 0:
#                    print('Iteration: {}, Score: {}, Mean of last 50 Episodes: {}'.format(Episode, i+1, np.mean(total_score[-50:])))
                break


        returns = compute_returns(rewards)
        returns = torch.cat(returns).detach()
        
        log_probs = torch.cat(log_probs)
        
        values = torch.cat(values)
        advantage = returns - values
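        # Gradient flow: actor_loss backpropagates only through log_probs (policy head,
        # plus the shared body in the combined case); the detached advantage keeps it
        # away from the value path. critic_loss backpropagates through the predicted
        # values (value head / critic, plus the shared body in the combined case).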
        
        actor_loss = -(log_probs * advantage.detach()).mean()
        critic_loss = advantage.pow(2).mean()
        actor_critic_loss = actor_loss + critic_loss
        
        if AC_approach == "Actor_Critic_separate":        
            optimizerA.zero_grad()
            optimizerC.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            optimizerA.step()
            optimizerC.step()
               
        if AC_approach == "Actor_Critic_combined":             
            optimizerAC.zero_grad()
            actor_critic_loss.backward()
            optimizerAC.step()
        
        if np.mean(total_score[-50:]) > 195:
            print("Solved in {} Episodes, with {} total Iterations; Mean Score {}"
                  .format(Episode, total_count, np.mean(total_score[-50:])))
            return Episode

        if Episode > 1000:
            print("NOT Solved in {} Episodes, with {} total Iterations; Mean Score {}"
                  .format(Episode, total_count, np.mean(total_score[-50:])))
            return Episode

    env.close()


if __name__ == '__main__':
    actor = Actor(state_size, action_size).to(device)
    critic = Critic(state_size, action_size).to(device)
    actor_critic = Actor_Critic(state_size, action_size).to(device)

    number_iteration = 10
  
    learn_rate = 0.01  

    Episode_count = []
    AC_approach="Actor_Critic_separate"
    print("Results for",AC_approach,":")
    for ix in range(number_iteration):
        actor = Actor(state_size, action_size).to(device)
        critic = Critic(state_size, action_size).to(device)
        actor_critic = Actor_Critic(state_size, action_size).to(device)            
        e = trainIters(actor, critic, actor_critic, AC_approach, learn_rate, n_iters=1000)
        Episode_count.append(e)

    print("Statistics for ",AC_approach,"\n-> Mean number of Episodes: {}, Std.Deviation: {}, Min number of Episodes: {}"
          .format(np.mean(Episode_count),np.std(Episode_count),min(Episode_count) ))
    print()
    
    Episode_count = []
    AC_approach="Actor_Critic_combined"
    print("Results for ",AC_approach,":")
    for ix in range(number_iteration):
        actor = Actor(state_size, action_size).to(device)
        critic = Critic(state_size, action_size).to(device)
        actor_critic = Actor_Critic(state_size, action_size).to(device)
        e = trainIters(actor, critic, actor_critic, AC_approach, learn_rate, n_iters=1000)
        Episode_count.append(e)
        
    print("Statistics for ",AC_approach,"\n-> Mean number of Episodes: {}, Std.Deviation: {}, Min number of Episodes: {}"
          .format(np.mean(Episode_count),np.std(Episode_count),min(Episode_count) ))