Replacing Q lookup table with neural network

Based on the tutorial https://towardsdatascience.com/reinforcement-learning-temporal-difference-sarsa-q-learning-expected-sarsa-on-python-9fecfda7467e, the following code successfully trains an RL agent to make decisions in the OpenAI Gym ‘Taxi-v3’ environment.

```python
import gym
import numpy as np
import time

"""
SARSA on-policy learning, Python implementation.
This is a Python implementation of the SARSA algorithm from Sutton and Barto's book on
RL. It's called SARSA because each update uses the tuple (state, action, reward, next state,
next action). The only difference between SARSA and Q-learning is that SARSA bootstraps from
the next action chosen by the current policy, while Q-learning bootstraps from the action with
the maximum value in the next state.
The environment is Taxi-v3, using the pre-0.26 gym API: env.reset() returns the state and
env.step() returns (state, reward, done, info).
"""

def init_q(s, a, type="ones"):
    """
    @param s the number of states
    @param a the number of actions
    @param type random, ones or zeros for the initialization
    """
    if type == "ones":
        return np.ones((s, a))
    elif type == "random":
        return np.random.random((s, a))
    elif type == "zeros":
        return np.zeros((s, a))
    raise ValueError(f"unknown initialization type: {type}")


def epsilon_greedy(Q, epsilon, n_actions, s, train=False):
    """
    @param Q Q values state x action -> value
    @param epsilon probability of taking the greedy action (the reverse of the usual
           convention, so epsilon = 0.9 means 10% random actions)
    @param n_actions number of actions
    @param s the current state
    @param train if True, always take the greedy action (used when testing)
    """
    if train or np.random.rand() < epsilon:
        action = np.argmax(Q[s, :])
    else:
        action = np.random.randint(0, n_actions)
    return action

def sarsa(alpha, gamma, epsilon, episodes, max_steps, n_tests, render=True, test=False):
    """
    @param alpha learning rate
    @param gamma discount factor
    @param epsilon for exploration (see epsilon_greedy)
    @param episodes number of training episodes
    @param max_steps maximum number of steps in each episode
    @param n_tests number of test episodes
    @param render if True, render the environment and print progress
    @param test if True, run test_agent after training
    """
    env = gym.make('Taxi-v3')
    n_states, n_actions = env.observation_space.n, env.action_space.n
    print('n_states', n_states)
    Q = init_q(n_states, n_actions, type="ones")
    print('Q shape:', Q.shape)

    timestep_reward = []
    for episode in range(episodes):
        print(f"Episode: {episode}")
        total_reward = 0
        s = env.reset()
        print('s:', s)
        a = epsilon_greedy(Q, epsilon, n_actions, s)
        t = 0
        done = False
        while t < max_steps:
            if render:
                env.render()
            t += 1
            s_, reward, done, info = env.step(a)
            print('state is', s)
            total_reward += reward
            a_ = epsilon_greedy(Q, epsilon, n_actions, s_)
            if done:
                # Terminal step: there is no next state to bootstrap from.
                Q[s, a] += alpha * (reward - Q[s, a])
            else:
                # SARSA update: bootstrap from the action a_ actually chosen in s_.
                Q[s, a] += alpha * (reward + gamma * Q[s_, a_] - Q[s, a])
            s, a = s_, a_
            if done:
                if render:
                    print(f"This episode took {t} timesteps and reward {total_reward}")
                timestep_reward.append(total_reward)
                break
            # print('Updated Q values:', Q)
    if render:
        print(f"Here are the Q values:\n{Q}\nTesting now:")
    if test:
        test_agent(Q, env, n_tests, n_actions)
    return timestep_reward

def test_agent(Q, env, n_tests, n_actions, delay=0.1):
    """Run n_tests greedy episodes with the learned Q table and render them."""
    for test in range(n_tests):
        print(f"Test #{test}")
        s = env.reset()
        done = False
        epsilon = 0
        total_reward = 0
        while True:
            time.sleep(delay)
            env.render()
            a = epsilon_greedy(Q, epsilon, n_actions, s, train=True)
            print(f"Chose action {a} for state {s}")
            s, reward, done, info = env.step(a)
            total_reward += reward
            if done:
                print(f"Episode reward: {total_reward}")
                time.sleep(1)
                break


if __name__ == "__main__":
    alpha = 0.4
    gamma = 0.999
    epsilon = 0.9
    episodes = 200
    max_steps = 20
    n_tests = 20
    timestep_reward = sarsa(alpha, gamma, epsilon, episodes, max_steps, n_tests)
    print(timestep_reward)
```

I’m attempting to modify this code to use Deep Q-Learning (DQN) instead of the tabular Q lookup.
The paper https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf
mentions using a separate target network, which I think translates to swapping the Q lookup table for
a neural network function approximator. To keep things as simple as possible, I’m not planning to
use replay memory in this initial solution.
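For reference, in Mnih et al. the Q lookup table is replaced by the Q-network itself; the separate target network is a second copy of that network, re-synced from it every C steps and used only to compute the bootstrap target, so the target does not shift with every gradient step. Below is a minimal numpy sketch without replay memory, assuming a one-hot encoding of Taxi's discrete states and a single linear layer. With that assumption the gradient step reduces exactly to the tabular update, which makes it easy to compare with the code above. `LinearQ`, `DQNLearner` and `sync_every` are hypothetical names, not from any library.

```python
import copy

import numpy as np


class LinearQ:
    """Hypothetical one-layer network: Q(s, .) = one_hot(s) @ W."""

    def __init__(self, n_states, n_actions, seed=0):
        rng = np.random.default_rng(seed)
        self.n_states = n_states
        self.W = rng.normal(scale=0.01, size=(n_states, n_actions))

    def __call__(self, s):
        x = np.zeros(self.n_states)
        x[s] = 1.0
        return x @ self.W  # one Q-value per action

    def update(self, s, a, y, lr):
        """One gradient step on 0.5 * (y - Q(s, a))**2, with y held constant."""
        error = y - self(s)[a]
        # For a one-hot input, dQ(s, a)/dW[s, a] = 1 and every other weight has
        # zero gradient, so this is exactly the tabular update.
        self.W[s, a] += lr * error


class DQNLearner:
    """Online network plus a periodically synced target network, no replay memory."""

    def __init__(self, n_states, n_actions, gamma=0.999, lr=0.4, sync_every=100):
        self.online = LinearQ(n_states, n_actions)
        self.target = copy.deepcopy(self.online)  # frozen copy used for targets
        self.gamma, self.lr, self.sync_every = gamma, lr, sync_every
        self.steps = 0

    def learn(self, s, a, reward, s_next, done):
        # Bootstrap from the target network, not from the network being trained.
        if done:
            y = reward
        else:
            y = reward + self.gamma * np.max(self.target(s_next))
        self.online.update(s, a, y, self.lr)
        self.steps += 1
        if self.steps % self.sync_every == 0:
            self.target = copy.deepcopy(self.online)  # the "every C steps" sync
        return y
```

Swapping `LinearQ` for a network with hidden layers keeps the same structure; only `update` changes to a proper backpropagation step.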

The paper ‘Combining Q-Learning with Artificial Neural Networks in an Adaptive Light Seeking Robot’
(source: https://pdfs.semanticscholar.org/79db/40a28420ebd2d108a1401db195dd37e9aefd.pdf) states, about replacing the Q-table lookup with a neural network:
“Obtain Q(x, a) for each action by substituting the state and action pairs into the neural net, keeping
track of those values.” and “Generate Qtarget(x,a) according to equation 1 and use Qtarget to train the net
as shown in fig 8 below.”
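The procedure quoted above describes a network that takes a (state, action) pair as input and outputs a single Q-value, so getting Q(x, a) for every action takes one forward pass per action. A minimal numpy sketch of that design, assuming the input is a one-hot state concatenated with a one-hot action and that equation 1 in that paper is the standard Q-learning target; the encoding, layer sizes and all names are hypothetical.

```python
import numpy as np

N_STATES, N_ACTIONS = 500, 6  # Taxi-v3 sizes


def encode(s, a, n_states=N_STATES, n_actions=N_ACTIONS):
    """Concatenate a one-hot state and a one-hot action (an assumed encoding)."""
    x = np.zeros(n_states + n_actions)
    x[s] = 1.0
    x[n_states + a] = 1.0
    return x


class StateActionNet:
    """Tiny MLP: (state, action) features -> a single Q-value."""

    def __init__(self, n_in, n_hidden=32, seed=0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(scale=0.1, size=(n_in, n_hidden))
        self.b1 = np.zeros(n_hidden)
        self.w2 = rng.normal(scale=0.1, size=n_hidden)
        self.b2 = 0.0

    def __call__(self, x):
        h = np.maximum(0.0, x @ self.W1 + self.b1)  # ReLU hidden layer
        return float(h @ self.w2 + self.b2)


def q_for_all_actions(net, s, n_actions=N_ACTIONS):
    # "Obtain Q(x, a) for each action by substituting the state and action
    # pairs into the neural net": one forward pass per action.
    return np.array([net(encode(s, a)) for a in range(n_actions)])


def q_target(net, reward, s_next, done, gamma=0.999):
    # Q-learning target: reward + gamma * max_a' Q(s', a'); just reward if terminal.
    if done:
        return reward
    return reward + gamma * q_for_all_actions(net, s_next).max()
```

This costs `n_actions` forward passes per state. The architecture in Mnih et al. instead has one output unit per action, so a single forward pass gives all six values.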

Does this mean replacing

```python
Q[s, a] += alpha * ( reward + (gamma * Q[s_, a_] ) - Q[s, a] )
```

with a neural network whose input is the state (s) and whose outputs are the values of each action (a)?
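For reference, in the common DQN set-up (the one in Mnih et al.) the network takes the state as input and has one output per action, so a single forward pass produces the whole row Q[s, :]. Taxi-v3 observations are plain integers (0 to 499), so a one-hot encoding is one simple choice of input; that choice, the layer sizes and all function names below are assumptions for illustration. Note also that this `epsilon_greedy` uses the usual convention (random action with probability epsilon), whereas the `epsilon_greedy` in the code above takes the greedy action with probability epsilon.

```python
import numpy as np

N_STATES, N_ACTIONS = 500, 6  # Taxi-v3: env.observation_space.n, env.action_space.n


def one_hot(s, n_states=N_STATES):
    """Taxi-v3 states are integers, so encode them as one-hot vectors."""
    x = np.zeros(n_states)
    x[s] = 1.0
    return x


def init_params(n_hidden=64, seed=0):
    """Random weights for a hypothetical 500 -> 64 (ReLU) -> 6 network."""
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.normal(scale=0.1, size=(N_STATES, n_hidden)),
        "b1": np.zeros(n_hidden),
        "W2": rng.normal(scale=0.1, size=(n_hidden, N_ACTIONS)),
        "b2": np.zeros(N_ACTIONS),
    }


def q_values(params, s):
    """Forward pass: one Q-value per action, playing the role of the row Q[s, :]."""
    h = np.maximum(0.0, one_hot(s) @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


def epsilon_greedy(params, s, epsilon, rng):
    """Usual convention: a random action with probability epsilon, else greedy."""
    if rng.random() < epsilon:
        return int(rng.integers(N_ACTIONS))
    return int(np.argmax(q_values(params, s)))
```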

Hi Adrian,

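Note that for Q-learning you need a max over the Q[s_, a_] values: the target uses the maximum value over all actions in s_, i.e. max(Q[s_, :]). Using the action a_ that your policy actually chose next, as your update does, is the SARSA target; DQN is built on Q-learning, so it uses the max.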
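To make the difference concrete on the tabular version: SARSA and Q-learning differ only in how the bootstrap target is formed, while the step toward that target is identical. A minimal sketch; the function names are hypothetical.

```python
import numpy as np


def sarsa_target(Q, reward, s_next, a_next, done, gamma):
    """On-policy: bootstrap from the action the policy actually picked next."""
    return reward if done else reward + gamma * Q[s_next, a_next]


def q_learning_target(Q, reward, s_next, done, gamma):
    """Off-policy (Q-learning, and hence DQN): bootstrap from the best next action."""
    return reward if done else reward + gamma * np.max(Q[s_next, :])


def update(Q, s, a, target, alpha):
    """Move Q[s, a] a fraction alpha toward the target (same step for both methods)."""
    Q[s, a] += alpha * (target - Q[s, a])
```

In the `sarsa()` loop above, computing the non-terminal target with `q_learning_target(...)` instead turns it into tabular Q-learning; `a_` is then still used to choose the next move but no longer appears in the update.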

The neural network works as a function approximator here: instead of looking up a table, you use the network to get the Q-values for all actions in a state. You predict those values by feeding the state to the network, i.e.:

```python
a_vector_of_Q_values_for_all_actions_in_s__ = model(s_)
```

and then take the max of them to form the target Y = reward + gamma * max_a' Q(s_, a') (or just Y = reward when s_ is terminal, as in the done branch of your code).

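Now, if X = model(s), then backpropagating loss(X[a], Y), e.g. the squared error (X[a] - Y)^2, constitutes the learning part of the algorithm, which replaces the Q-value update in tabular Q-learning. Y is treated as a fixed target (no gradient flows through it), and only the output X[a] for the action actually taken enters the loss.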
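A sketch of that learning step without a deep-learning framework, assuming a one-hot state input, a single ReLU hidden layer and a squared-error loss, with the gradients written out by hand in numpy. No target network or replay memory is used, matching the simplified plan in the question; `QNetwork` and `train_step` are hypothetical names. In practice a framework such as PyTorch or TensorFlow would compute these gradients automatically.

```python
import numpy as np


class QNetwork:
    """Hypothetical MLP: one-hot state -> ReLU hidden layer -> one Q-value per action."""

    def __init__(self, n_states, n_actions, n_hidden=32, seed=0):
        rng = np.random.default_rng(seed)
        self.n_states = n_states
        self.W1 = rng.normal(scale=0.1, size=(n_states, n_hidden))
        self.b1 = np.zeros(n_hidden)
        self.W2 = rng.normal(scale=0.1, size=(n_hidden, n_actions))
        self.b2 = np.zeros(n_actions)

    def forward(self, s):
        x = np.zeros(self.n_states)
        x[s] = 1.0
        h = np.maximum(0.0, x @ self.W1 + self.b1)
        return x, h, h @ self.W2 + self.b2

    def __call__(self, s):
        return self.forward(s)[2]


def train_step(model, s, a, reward, s_next, done, gamma=0.999, lr=0.01):
    """One semi-gradient Q-learning step on loss 0.5 * (X[a] - Y)**2; returns the loss."""
    # Y = reward + gamma * max_a' Q(s_, a'); it is a plain number, so no gradient flows through it.
    y = reward if done else reward + gamma * np.max(model(s_next))
    x, h, X = model.forward(s)
    err = X[a] - y  # dLoss/dX[a]; the other outputs get zero gradient
    # Gradients for the output layer (only column a is involved) ...
    grad_W2_a = err * h
    grad_b2_a = err
    # ... and, through the ReLU, for the hidden layer (computed before W2 changes).
    grad_h = err * model.W2[:, a] * (h > 0)
    model.W2[:, a] -= lr * grad_W2_a
    model.b2[a] -= lr * grad_b2_a
    model.W1 -= lr * np.outer(x, grad_h)
    model.b1 -= lr * grad_h
    return 0.5 * err ** 2
```

In the `sarsa()` loop, a call like this would replace the two `Q[s, a] += ...` lines, and actions would be chosen from `np.argmax(model(s))` instead of `np.argmax(Q[s, :])`.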

I also found this post very useful to understand the transition from tabular Q-learning to DQN: https://towardsdatascience.com/why-going-from-implementing-q-learning-to-deep-q-learning-can-be-difficult-36e7ea1648af