I am trying to build a DQN agent that plays tic-tac-toe, but I just cannot get it to converge. I have tried both a convolutional and a fully connected feed-forward network, but the network never masters the game. I measure performance by letting the agent play against a random player and tracking the random player's win percentage; it should converge to zero, but I cannot get it below 30%. Could someone check whether I implemented something incorrectly? I would be very thankful. I don't know how many hours I have spent on hyperparameter search by now…
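To make the evaluation concrete, here is roughly what I mean by "letting the agent play against a random player" (a simplified sketch; `evaluate` and `n_games` are names I am using just for illustration, the actual bookkeeping happens inside the training loop below):

# Rough sketch of my evaluation setup. The trained agent (X) plays greedily
# (eps = 0) against the random player (O), and I track how often the random
# player wins. select_action, select_random_action, play_move and
# check_result are the functions defined in the full code below.
def evaluate(n_games=700):
    random_wins = 0
    for _ in range(n_games):
        state = [0] * 9
        players_turn = random.choice([0, 1])
        game_status = 'Not Done'
        winner = None
        while game_status == 'Not Done':
            if players_turn == 0:
                action = select_action(state, 0.0)  # greedy agent move
                state = play_move(state, 1, action)
            else:
                action = select_random_action(state)
                state = play_move(state, -1, action)
            winner, game_status = check_result(state)
            players_turn = (players_turn + 1) % 2
        if winner == -1:
            random_wins += 1
    # should converge to 0 %, but it stays around 30 % for me
    return 100.0 * random_wins / n_games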
Here is my code:
import numpy as np
import random
import matplotlib.pyplot as plt
import pickle
import torch
import torch.nn as nn
torch.manual_seed(0)
# returns new state
def play_move(st, player, move):
    state_copy = st.copy()
    while state_copy[move] != 0:
        move = int(input('This space is already occupied! Choose again: ')) - 1
    state_copy[move] = player
    return state_copy
# returns winner, game status ('Done', 'Not Done', 'Draw')
def check_result(st):
    # check horizontal
    if st[0] == st[1] == st[2] and st[0] != 0:
        return st[0], 'Done'
    if st[3] == st[4] == st[5] and st[3] != 0:
        return st[3], 'Done'
    if st[6] == st[7] == st[8] and st[6] != 0:
        return st[6], 'Done'
    # check vertical
    if st[0] == st[3] == st[6] and st[0] != 0:
        return st[0], 'Done'
    if st[1] == st[4] == st[7] and st[1] != 0:
        return st[1], 'Done'
    if st[2] == st[5] == st[8] and st[2] != 0:
        return st[2], 'Done'
    # check diagonal
    if st[0] == st[4] == st[8] and st[0] != 0:
        return st[0], 'Done'
    if st[2] == st[4] == st[6] and st[2] != 0:
        return st[2], 'Done'
    # check for draw
    if st.count(0) == 0:
        return None, 'Draw'
    # otherwise
    return None, 'Not Done'
# prints board
def print_board(st):
    state_copy = st.copy()
    for i, symbol in enumerate(state_copy):
        if symbol == 0:
            state_copy[i] = ' '
        elif symbol == 1:
            state_copy[i] = 'X'
        elif symbol == -1:
            state_copy[i] = 'O'
    print(' ' + str(state_copy[0]) + ' | ' + str(state_copy[1]) + ' | ' + str(state_copy[2]) + ' ')
    print('-------------')
    print(' ' + str(state_copy[3]) + ' | ' + str(state_copy[4]) + ' | ' + str(state_copy[5]) + ' ')
    print('-------------')
    print(' ' + str(state_copy[6]) + ' | ' + str(state_copy[7]) + ' | ' + str(state_copy[8]) + ' ')
class CNN(nn.Module):
    def __init__(self, kernel):
        super(CNN, self).__init__()
        self.process_cnn = nn.Sequential(
            # output stays 3x3 (formula: out = (in - k + 2*p) / s + 1)
            nn.Conv2d(1, kernel, 3, stride=1, padding=1),
            nn.ReLU(True),
            # output stays 3x3
            nn.Conv2d(kernel, kernel, 3, stride=1, padding=1),
            nn.ReLU(True)
        )
        self.process_lin = nn.Sequential(
            nn.Linear(3 * 3 * kernel, 9)
        )

    def process(self, x):
        x = x.view(-1, 1, 3, 3)
        # apply convolutions
        x = self.process_cnn(x)
        # flatten
        x = x.view([x.size(0), -1])
        # apply linear layers
        x = self.process_lin(x)
        return x

    def forward(self, x):
        x = self.process(x)
        return x
class FF(nn.Module):
    def __init__(self, Nh1, Nh2, Nh3):
        super(FF, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(9, Nh1),
            nn.ReLU(True),
            nn.Linear(Nh1, Nh2),
            nn.ReLU(True),
            nn.Linear(Nh2, Nh3),
            nn.ReLU(True),
            nn.Linear(Nh3, 9)
        )

    def forward(self, x):
        x = x.view(-1, 9)
        x = self.fc(x)
        return x
# Initialize network
Nh1 = 16
Nh2 = 32
Nh3 = 16
# net = FF(Nh1, Nh2, Nh3)
# target_net = FF(Nh1, Nh2, Nh3)
kernel = 4
net = CNN(kernel)
target_net = pickle.loads(pickle.dumps(net))
# Define Loss Function
loss_fn = nn.MSELoss()
def select_action(st, eps):
    # collect legal moves (empty cells)
    allowed_actions = []
    for i, cell in enumerate(st):
        if cell == 0:
            allowed_actions.append(i)
    # pass state through NN
    input = torch.tensor(st).float()
    output = net(input).float().squeeze().detach().numpy()
    # get Q values of legal moves
    qvalues = []
    for action in allowed_actions:
        qvalues.append(output[action])
    # epsilon-greedy behavior policy: greedy move with probability 1 - eps,
    # remaining probability mass spread evenly over the other legal moves
    if len(allowed_actions) > 1:
        prob = np.ones(len(allowed_actions)) * eps / (len(allowed_actions) - 1)
        prob[np.argmax(qvalues)] = 1 - eps
    else:
        prob = [1]
    # sample one of the legal moves according to prob
    chosen_action = allowed_actions[np.random.choice(range(0, len(allowed_actions)), p=prob)]
    return chosen_action
def get_pred_and_target(st, next_state, act, player, discount):
    # predicted Q value of the chosen action in the current state
    pred = torch.tensor([net(torch.tensor(st).float()).squeeze().detach().numpy()[act]])
    # define reward
    reward = 0.
    winner, game_status = check_result(next_state)
    if game_status == 'Done' and winner == player:
        reward = 1.
    if game_status == 'Done' and winner != player:
        reward = -1.
    if game_status == 'Draw':
        reward = 0.
    # define target
    if next_state.count(0) == 0:
        target = torch.tensor([reward], requires_grad=True).float()
    else:
        target = torch.tensor([reward]).float() + discount * torch.max(
            target_net(torch.tensor(st).float()))
    return pred, target
def select_random_action(state):
    # create array with allowed moves
    allowed_actions = []
    for i, cell in enumerate(state):
        if cell == 0:
            allowed_actions.append(i)
    chosen_action = allowed_actions[np.random.choice(range(0, len(allowed_actions)))]
    return chosen_action
# Training against random player
num_epochs = 10000
epsilon_array = np.linspace(0.3, 0.01, num_epochs) # epsilon decays with every epoch
lr_array = np.linspace(0.000001, .0000001, num_epochs)
results = []
percentages = []
preds = torch.tensor([0]).float()
targets = torch.tensor([0]).float()
training = True
playing = False
batch_size = 10
update_target = 100
if training:
    for epoch in range(num_epochs):
        lr = lr_array[epoch]
        # Define Optimizer (learning rate is annealed every epoch)
        optimizer = torch.optim.Adam(net.parameters(), lr=lr_array[epoch], weight_decay=1e-5)
        # Produce batch
        for i in range(batch_size):
            # Clear Board
            state = [0, 0, 0, 0, 0, 0, 0, 0, 0]
            epsilon = epsilon_array[epoch]
            game_status = 'Not Done'
            winner = None
            players_turn = random.choice([0, 1])
            while game_status == 'Not Done':
                if players_turn == 0:  # X's move
                    # print("\nAI X's turn!")
                    action = select_action(state, epsilon)
                    new_state = play_move(state, 1, action)
                else:  # O's move
                    # print("\nAI O's turn!")
                    action = select_random_action(state)
                    new_state = play_move(state, -1, action)
                # get pred and target for Q(s, a)
                pred, target = get_pred_and_target(state, new_state, action, 1, discount=0.99)
                # update batch
                preds = torch.cat([preds, pred])
                targets = torch.cat([targets, target])
                # update state
                state = new_state.copy()
                # print_board(new_state)
                winner, game_status = check_result(state)
                if winner is not None:
                    # print(str(winner) + ' won!')
                    if winner == 1:
                        results.append('X')
                    else:
                        results.append('O')
                else:
                    players_turn = (players_turn + 1) % 2
            if game_status == 'Draw':
                # print('Draw!')
                results.append('Draw')
        loss = loss_fn(preds, targets)
        optimizer.zero_grad()
        # Backward pass
        loss.backward()
        # Update
        optimizer.step()
        # Clear batch
        preds = torch.tensor([0]).float()
        targets = torch.tensor([0]).float()
        # update target net and report progress
        if epoch % update_target == 0:
            print('Epoch: ' + str(epoch))
            print(torch.mean(loss))
            print(pred)
            target_net = pickle.loads(pickle.dumps(net))
            percentage = results[-700:].count('O') / 7
            percentages.append(percentage)
            print(f'Random player win percentage in the last 700 games: {percentage} %')
    print('Training Complete')
# Plot the random player's win-percentage curve over training
x = np.linspace(1, len(percentages), len(percentages))
plt.plot(x, percentages)
plt.show()
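One implementation detail I am not sure matters: I sync the target network with a pickle round-trip (target_net = pickle.loads(pickle.dumps(net))). I assume the more idiomatic PyTorch equivalents below would behave the same way:

import copy

# both should be equivalent to the pickle round-trip used above:
target_net = copy.deepcopy(net)               # fresh copy of the whole module
# or, keeping the same target_net object and only copying the weights:
target_net.load_state_dict(net.state_dict())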