Actor Critic Loss explodes

Hello, I have some problems implementing the actor-critic policy gradient algorithm.

When I implement REINFORCE like this, everything is okay:

class Estimator(nn.Module):
    def __init__(self, num_actions):
        super().__init__()
        self.num_actions = num_actions
        self.dense_1 = nn.Linear(4, 32)
        self.out = nn.Linear(32, num_actions)

    def forward(self, x):
        x = self.dense_1(x)
        x = F.softmax(self.out(x))
        return x
    
env = gym.make("CartPole-v0")
estimator = Estimator(2)
estimator.cuda()
opt = optim.Adam(estimator.parameters())
loss = []
running_reward = 0
for i in range(100000): # number episodes
    episode = []
    chosen_actions = []
    rewards = []
    done = False
    state = env.reset()
                  
    while not done:
        probs = estimator(Variable(torch.unsqueeze(torch.from_numpy(state),0).float().cuda())) # calculate the probs of choosing actions
        action = probs.multinomial()
        chosen_actions.append(action)
        next_state, reward, done, _ = env.step(action.data[0,0])
        rewards.append(reward)
        state = next_state
       
    
    # compute (undiscounted) returns by summing rewards backwards; the returns end up at the
    # front of the list, the raw rewards stay at the back and are ignored by the zip below
    R = 0
    for r in rewards[::-1]:
        R = r + R
        rewards.insert(0, R)
        
    for action, r in zip(chosen_actions, rewards):
        action.reinforce(r)
        
    opt.zero_grad()
    autograd.backward(chosen_actions, [None for _ in chosen_actions])
    opt.step()
    running_reward = running_reward * 0.99 + len(chosen_actions) * 0.01
    if (i+1) % 10 == 0:
        print("Episode: {} Running Reward: {}".format(i+1,round(running_reward,2)))

But when I try to implement actor critic using the Generalized Advantage Estimator, the algorithm fails.
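
For reference, this is the advantage estimate I have in mind. It is just a plain-Python sketch (the helper name and the gamma/lam values are my own picks for illustration; the failing code below effectively only uses the one-step TD error):

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    # rewards[t] is r_t, values[t] is V(s_t); values needs one extra entry,
    # values[-1] = V of the state after the last step (0.0 if the episode ended there)
    advantages = [0.0] * len(rewards)
    adv = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # one-step TD error
        adv = delta + gamma * lam * adv                         # discounted sum of TD errors
        advantages[t] = adv
    return advantages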

Two things are happening:

  1. The policy learns to ALWAYS choose one of the actions (its probability approaches 1).
  2. The loss (and the output) of the state value estimator explodes.

I checked the policy gradient part of the implementation by pretraining a state value estimator using TD and then plugging it into the code. That works just fine, so I suspect I have made some error in the state value updates.
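
(By "using TD" I mean fitting V to the usual one-step bootstrap target, roughly like the sketch below. The function name, gamma, and the terminal-state masking are my choices for the sketch, not necessarily what the failing code further down does.)

def td0_value_update(v_estimator, v_opt, mse, s_batch, r_batch, d_batch, s2_batch, gamma=0.99):
    s = Variable(torch.from_numpy(s_batch).float().cuda())
    s2 = Variable(torch.from_numpy(s2_batch).float().cuda())
    r = Variable(torch.unsqueeze(torch.from_numpy(r_batch).float().cuda(), -1))
    not_done = Variable(torch.unsqueeze(torch.from_numpy(1.0 - d_batch.astype(np.float32)).cuda(), -1))
    target = (r + gamma * not_done * v_estimator(s2)).detach()  # r + gamma * V(s'), zeroed at terminals
    v_loss = mse(v_estimator(s), target)
    v_opt.zero_grad()
    v_loss.backward()
    v_opt.step()
    return v_loss.data[0]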

Here is the code:

(I tried a lot of different hyperparameters, like the learning rates of both optimizers and the number of state value estimator updates per timestep, and I also tried starting the policy gradient updates only after the state value predictions had some time to learn; unfortunately nothing has worked…)

class Estimator(nn.Module):
    def __init__(self, num_actions):
        super().__init__()
        self.num_actions = num_actions
        self.dense_1 = nn.Linear(4, 32)
        self.out = nn.Linear(32, num_actions)

    def forward(self, x):
        x = self.dense_1(x)
        x = F.softmax(self.out(x))
        return x

class V_Estimator(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense_1 = nn.Linear(4, 32)
        self.out = nn.Linear(32, 1)

    def forward(self,x):
        x = F.relu(self.dense_1(x))
        x = self.out(x)
        return x
    
    
estimator = Estimator(2)
estimator.cuda()
v_estimator = V_Estimator()
v_estimator.cuda()
opt = optim.Adam(estimator.parameters(), lr=0.0001)
v_opt = optim.Adam(v_estimator.parameters(), lr=0.0001)
env = gym.make("CartPole-v0")
mse = nn.MSELoss()
buffer = ReplayBuffer(100000)
running_reward = 0
for i in range(10000): # number episodes
    episode_len = 0
    done = False
    state = env.reset()
    

    while not done:
        episode_len += 1
        state_python = state  # keep the raw numpy state around for the replay buffer
        state = Variable(torch.unsqueeze(torch.from_numpy(state),0).float().cuda())
        probs = estimator(state) 
        #print(probs.data.cpu().numpy()) # one of the action probabilites just approaches 1
        action = probs.multinomial()

        action_python = action.data[0,0]
        v_estimate_curr = v_estimator(state)
        #v_estimate_curr = v_estimate(state)
        #print(v_estimate_curr)
        next_state, reward, done, _ = env.step(action_python)
        v_estimate_next = v_estimator(Variable(torch.unsqueeze(torch.from_numpy(next_state),0).float().cuda()))
        #v_estimate_next = v_estimate(Variable(torch.unsqueeze(torch.from_numpy(next_state),0).float().cuda()))
        #print(v_estimate_next)
        
        td_error = reward + v_estimate_next - v_estimate_curr  # one-step TD error, used as the advantage
        
        
        buffer.add(state_python, action_python, reward, done, next_state)
        state = next_state
        
        #refit v-estimator
        average_state_value_loss = 0
        state_value_updates = 30
        for j in range(state_value_updates):
            s_batch, a_batch, r_batch, d_batch, s2_batch  = buffer.sample_batch(128)
            #print("s_batch shape: {}".format(s_batch.shape))
            targ = v_estimator(Variable(torch.from_numpy(s2_batch)).float().cuda())
            #print("targ shape: {}".format(targ.data.cpu().numpy().shape))
            torch_rew_batch = Variable(torch.unsqueeze(torch.from_numpy(r_batch).float().cuda(),-1))
            #print("torch_rew_batch shape: {}".format(torch_rew_batch.data.cpu().numpy().shape))
            targ = targ + torch_rew_batch  # bootstrap target: r + V(s')
            targ = targ.detach()           # do not backprop through the target
            targ.requires_grad = False
            #print("targ shape: {}".format(targ.data.cpu().numpy().shape))
            out = v_estimator(Variable(torch.from_numpy(s_batch)).float().cuda())
            #print("out shape: {}".format(out.data.cpu().numpy().shape))
            v_loss = mse(out, targ)
            average_state_value_loss += v_loss.data[0] / state_value_updates
            
            v_opt.zero_grad()
            v_loss.backward()
            v_opt.step()

        # update policy gradient
        #if i > 100: # starting after 100 episodes to give the state value nn some time to learn
        opt.zero_grad()
        action.reinforce(td_error.data)  # scale the sampled action's log-prob gradient by the TD error
        action.backward()
        opt.step()
    running_reward = running_reward * 0.9 + episode_len * 0.1
    print("current episode: " + str(i)+ " - running reward: " + str(round(running_reward,2)) + " - average state value estimator loss: {}".format(average_state_value_loss))

I looked at the actor-critic implementation in the pytorch examples repo, but it does some things a little differently (like sharing parameters between the policy and the value function).
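
(Roughly, this is the kind of layout I mean by "sharing parameters"; the class and layer names here are mine, not the repo's, just to illustrate the idea:)

class SharedActorCritic(nn.Module):
    def __init__(self, num_actions):
        super().__init__()
        self.body = nn.Linear(4, 32)                   # feature layer shared by actor and critic
        self.policy_head = nn.Linear(32, num_actions)  # actor head
        self.value_head = nn.Linear(32, 1)             # critic head

    def forward(self, x):
        x = F.relu(self.body(x))
        probs = F.softmax(self.policy_head(x))  # action probabilities
        value = self.value_head(x)              # state value estimate
        return probs, value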

If anybody has any idea how to fix the error, I would greatly appreciate it.

Johannes

PS: for reproducibility, execute this first; then both code samples above will run:

import torch
from torch.autograd import Variable
from torch import autograd
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
import gym

from collections import deque
import random
import numpy as np

class ReplayBuffer(object):

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.count = 0
        self.buffer = deque()

    def add(self, s, a, r, d, s2):
        experience = (s, a, r, d, s2)
        if self.count < self.buffer_size:
            self.buffer.append(experience)
            self.count += 1
        else:
            self.buffer.popleft()
            self.buffer.append(experience)

    def size(self):
        return self.count

    def sample_batch(self, batch_size):
        '''
        batch_size specifies the number of experiences to sample
        for the batch. If the replay buffer has fewer than batch_size
        elements, simply return all of the elements within the buffer.
        '''

        if self.count < batch_size:
            batch = random.sample(self.buffer, self.count)
        else:
            batch = random.sample(self.buffer, batch_size)

        s_batch = np.array([np.array(_[0]) for _ in batch])
        a_batch = np.array([_[1] for _ in batch])
        r_batch = np.array([_[2] for _ in batch])
        d_batch = np.array([_[3] for _ in batch])
        s2_batch = np.array([np.array(_[4]) for _ in batch])

        return s_batch, a_batch, r_batch, d_batch, s2_batch

    def clear(self):
        self.buffer.clear()
        self.count = 0

I'm curious whether you managed to figure out what the reason was. I think I'm suffering from something similar…

Yes, I think I solved it. I will dig up the code later today and post it.

Hey @j.laute, I was wondering if there were any updates on this. Thanks!

EDIT: I just saw that the code sample above already uses the MSE loss; in that case I'm sorry, I have no idea how I solved it, it was quite some time ago. I will post when I'm back, but that is over a month away. I would recommend searching for a working PyTorch implementation online and comparing against it.

Sorry, I completely forgot to post the code. Unfortunately I am on summer break right now and have no access to it.

In my case the solution was to use the MSE loss instead of the (smooth) L1 loss, which was really unintuitive, as the smooth L1 loss is explicitly recommended to prevent problems caused by the MSE loss.
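
Roughly, the change amounted to nothing more than swapping the criterion used for the value regression (sketched from memory; criterion, out, and targ are just placeholder names for the sketch):

# before, the version whose loss exploded for me:
# criterion = nn.SmoothL1Loss()
# after, what ended up working:
criterion = nn.MSELoss()
v_loss = criterion(out, targ)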

I never got around to finding out why the L1 loss didn't work for me; surely I made some implementation mistake.

Cheers, Johannes