How to implement a simple LSTM in a reinforcement learning task ('CartPole-v0')

I’ve realised I don’t understand LSTMs in PyTorch quite as well as I thought, so I’m adapting the CartPole demo from Soumith Chintala to give myself a simple challenge: swapping the main Linear layer for an LSTM.

The example here fails on the first pass with:
RuntimeError: Input batch size 1 doesn’t match hidden[0] batch size 128
on
x, self.hidden = self.lstm(x, self.hidden)

Now if I change this line to
x, _ = self.lstm(x, self.hidden)

… it converges and completes the task in a reasonable 950 episodes.

However, if I do this it isn’t feeding the hidden states back in (forgive my improvised terminology), so it’s presumably not really taking advantage of the LSTM’s capabilities? I can’t quite suss out the error either, as the self.hidden I’m feeding in is shaped (1, 128), and the self.hidden it outputs also seems to be (1, 128).
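
For reference, this is how I understand the nn.LSTMCell call and its shapes (just a standalone sketch, separate from the demo code below):

import torch
import torch.nn as nn

cell = nn.LSTMCell(input_size=3, hidden_size=128)
x = torch.randn(1, 3)          # one timestep: (batch, input_size)
h0 = torch.zeros(1, 128)       # (batch, hidden_size)
c0 = torch.zeros(1, 128)       # (batch, hidden_size)
h1, c1 = cell(x, (h0, c0))     # LSTMCell returns an (h, c) tuple
print(h1.shape, c1.shape)      # both torch.Size([1, 128])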

I notice that more advanced implementations of A3C models for playing Atari games tend to make use of the hidden state values – up until now I’ve really not used that aspect, generally just writing x = self.lstm(x) … which is presumably fine so long as all your work is being done within batches?
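
For context, the pattern I mean looks like the snippet below – as far as I can tell from the docs, when no hidden state is passed in, nn.LSTM starts each forward call from zero states, so the recurrence only spans the sequence supplied in that call (a rough sketch, not part of the demo):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=128)
seq = torch.randn(20, 1, 4)    # a whole sequence: (seq_len, batch, input_size)
out, (h_n, c_n) = lstm(seq)    # hidden state defaults to zeros when omitted
print(out.shape)               # torch.Size([20, 1, 128]) – one output per timestep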

Thanks so much for any help. I’d also be interested in whether an nn.LSTM (rather than nn.LSTMCell) might be used for this problem, but I’m having separate problems getting the data into the right shape.
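
For reference, my current understanding of the shapes the full nn.LSTM expects, in case it helps (only a sketch, assuming a single observation per forward call):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=128)
state = torch.randn(4)                 # one CartPole observation
x = state.view(1, 1, 4)                # (seq_len=1, batch=1, input_size=4)
hidden = (torch.zeros(1, 1, 128),      # h_0: (num_layers, batch, hidden_size)
          torch.zeros(1, 1, 128))      # c_0: same shape
out, hidden = lstm(x, hidden)          # out: (1, 1, 128)
features = out.view(-1)                # flatten to (128,) for the linear heads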

import argparse
import gym
import numpy as np
from itertools import count
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical


GAMMA = 0.99
env = gym.make('CartPole-v0')
env.seed(1)
torch.manual_seed(1)

SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.lstm = nn.LSTMCell(3, 128)  # fed 3 of the 4 CartPole observation values (see select_action)
        self.action_head = nn.Linear(128, 2)
        self.value_head = nn.Linear(128, 1)
        self.saved_actions = []
        self.rewards = []
        self.hidden = None


    def forward(self, x):
        x = x.unsqueeze(0)                          # (3,) -> (1, 3): batch of one
        x, self.hidden = self.lstm(x, self.hidden)  # the line that raises the RuntimeError described above
        x = x.squeeze(0)
        action_scores = self.action_head(x)
        state_values = self.value_head(x)
        return F.softmax(action_scores, dim=-1), state_values


model = Policy()
optimizer = optim.Adam(model.parameters(), lr=0.001)
eps = np.finfo(np.float32).eps.item()


def select_action(state):
    state = torch.from_numpy(state).float()
    probs, state_value = model(state.narrow(0, 1, 3))  # drop the cart position, keeping 3 of the 4 observation values
    m = Categorical(probs)
    action = m.sample()
    model.saved_actions.append(SavedAction(m.log_prob(action), state_value))
    return action.item()


def finish_episode():
    R = 0
    saved_actions = model.saved_actions
    policy_losses = []
    value_losses = []
    rewards = []
    for r in model.rewards[::-1]:
        R = r + GAMMA * R  # discounted return, accumulated back to front
        rewards.insert(0, R)
    rewards = torch.tensor(rewards)
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)  # normalise the returns
    for (log_prob, value), r in zip(saved_actions, rewards):
        reward = r - value.item()  # advantage: return minus the critic's estimate
        policy_losses.append(-log_prob * reward)
        value_losses.append(F.smooth_l1_loss(value, torch.tensor([r])))
    optimizer.zero_grad()
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
    loss.backward()
    optimizer.step()
    del model.rewards[:]
    del model.saved_actions[:]


def main():
    running_reward = 10
    for i_episode in count(1):
        state = env.reset()
        for t in range(10000):
            action = select_action(state)
            state, reward, done, _ = env.step(action)
            model.rewards.append(reward)
            if done:
                break

        running_reward = running_reward * 0.99 + t * 0.01
        finish_episode()
        if i_episode % 10 == 0:
            print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
                i_episode, t, running_reward))
        if running_reward > env.spec.reward_threshold:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, t))
            for t in range(100000):
                action = select_action(state)
                state, reward, done, _ = env.step(action)
                env.render()
                model.rewards.append(reward)
            # if done:
            #     break
            break

if __name__ == '__main__':
    main()
    env.env.close()

Okay, I think I got it. This is the only way I learn.

So I reset the hidden states every time the episode is done, and otherwise .detach() them. It takes two or three times longer to train and does a pretty horrible job – but presumably LSTMs aren’t very well suited to this task.
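
In outline (the full listing below shows it in context), the pattern is:

# at the start of each episode: fresh zero states
self.hidden = (torch.zeros(1, HIDDEN_SIZE), torch.zeros(1, HIDDEN_SIZE))

# after every non-terminal step: keep the values but cut the graph, so the
# backward() at the end of the episode doesn't run through every earlier timestep
self.hidden = (self.hidden[0].detach(), self.hidden[1].detach())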

EDIT: one more update – it completes the task in around 1680 episodes on CartPole-v1, and it’s a somewhat human-looking model, with a natural wobble and some fairly dynamic rescues … unlike the machine precision of the usual solutions.

import argparse
import gym
import numpy as np
from itertools import count
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical


GAMMA = 0.99
HIDDEN_SIZE = 64
env = gym.make('CartPole-v1')
env.seed(1)
torch.manual_seed(1)

SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.lstm = nn.LSTMCell(4, HIDDEN_SIZE)  # now takes the full 4-value observation
        # self.affine = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)
        self.action_head = nn.Linear(HIDDEN_SIZE, 2)
        self.value_head = nn.Linear(HIDDEN_SIZE, 1)
        self.saved_actions = []
        self.rewards = []
        self.reset()
        
    def reset(self):
        # fresh zero (h, c) states at the start of each episode
        self.hidden = (torch.zeros(1, HIDDEN_SIZE), torch.zeros(1, HIDDEN_SIZE))

    def detach_weights(self):
        # detaches the hidden state (not the weights): keeps the values but
        # stops gradients flowing back through earlier timesteps
        self.hidden = self.hidden[0].detach(), self.hidden[1].detach()

    def forward(self, x):
        x = x.unsqueeze(0)                       # (4,) -> (1, 4): batch of one
        self.hidden = self.lstm(x, self.hidden)  # keep the full (h, c) tuple this time
        x = self.hidden[0]                       # use the hidden state h as the features
        x = x.squeeze(0)
        # x = self.affine(x)
        action_scores = self.action_head(x)
        state_values = self.value_head(x)
        return F.softmax(action_scores, dim=-1), state_values


model = Policy()
optimizer = optim.Adam(model.parameters(), lr=0.001)
eps = np.finfo(np.float32).eps.item()


def select_action(state):
    state = torch.from_numpy(state).float()
    probs, state_value = model(state)
    m = Categorical(probs)
    action = m.sample()
    model.saved_actions.append(SavedAction(m.log_prob(action), state_value))
    return action.item()


def finish_episode():
    R = 0
    saved_actions = model.saved_actions
    policy_losses = []
    value_losses = []
    rewards = []
    for r in model.rewards[::-1]:
        R = r + GAMMA * R
        rewards.insert(0, R)
    rewards = torch.tensor(rewards)
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
    for (log_prob, value), r in zip(saved_actions, rewards):
        reward = r - value.item()
        policy_losses.append(-log_prob * reward)
        value_losses.append(F.smooth_l1_loss(value, torch.tensor([r])))
    optimizer.zero_grad()
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
    loss.backward()
    optimizer.step()
    del model.rewards[:]
    del model.saved_actions[:]


def main():
    running_reward = 10
    for i_episode in count(1):
        state = env.reset()
        for t in range(10000):
            action = select_action(state)
            state, reward, done, _ = env.step(action)
            model.rewards.append(reward)
            if done:
                model.reset()            # fresh hidden state for the next episode
                break
            else:
                model.detach_weights()   # carry the state forward, but cut the graph

        running_reward = running_reward * 0.99 + t * 0.01
        finish_episode()
        if i_episode % 10 == 0:
            print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
                i_episode, t, running_reward))
        if running_reward > env.spec.reward_threshold:
            model.reset()
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, t))
            for t in range(100000):
                action = select_action(state)
                state, reward, done, _ = env.step(action)
                env.render()
                model.rewards.append(reward)
                # if done:
                #     break
            break

if __name__ == '__main__':
    main()
    env.env.close()
