DQN Snake not making great progress

Hi everyone,

I am trying to train a DQN agent for a multiplayer Snake-style game (Kaggle's Hungry Geese). I have tried several variants, but none of them learn well, so I suspect there is an error in my code. Could somebody take a look, please?

This is a dueling DQN. The network is a CNN that receives a 10 x 7 x 11 input (10 frames, 7 x 11 board). Frames 1-5 encode the last board position (myself, Opponent 1, Opponent 2, Opponent 3, food) and frames 6-10 encode the current board position. This matters because the move opposite to the last one (e.g. North after South) is forbidden. The reward is +1 for every step survived, -100 for dying and +10 for eating food. Is there an error somewhere?
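To make the input encoding concrete, this is roughly how one observation pair turns into the network input (it mirrors what the training loop below does with state_input_to_features and torch.cat):

    import torch

    last_frames = torch.zeros(5, 7, 11)     # previous board: me, opponents 1-3, food
    current_frames = torch.zeros(5, 7, 11)  # current board, same channel order
    net_input = torch.cat((last_frames, current_frames), dim=0)  # -> 10 x 7 x 11
    net_input = net_input.unsqueeze(dim=0)  # add batch dimension -> 1 x 10 x 7 x 11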

from kaggle_environments import evaluate, make

from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col, random_agent, greedy_agent

import numpy as np
import torch
import torch.nn as nn
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
#class to handle the storage of past transitions

class Replay_buffer():
    def __init__(self, capacity):
        self.capacity = capacity
        self.pos = 0
        self.memory = list()
        self.action_size = 1
        self.reward_size = 1

    def add(self, element):
        # convert to tensors once, when the transition is stored
        element['state'] = torch.Tensor(element['state'])
        element['action'] = torch.LongTensor([element['action']])
        element['new_state'] = torch.Tensor(element['new_state'])
        element['reward'] = torch.Tensor([element['reward']])
        if len(self.memory) < self.capacity:
            self.memory.append(element)
        else:
            # overwrite the oldest transition once the buffer is full
            self.memory[self.pos] = element
            self.pos = (self.pos + 1) % self.capacity

    def get_item(self, batch_size):
        return random.sample(self.memory, batch_size)

class Q_network(nn.Module):
    def __init__(self):
        super(Q_network, self).__init__()
        self.input = nn.Sequential(
            nn.Conv2d(10, 24, kernel_size=2, stride=2, padding=1),
            nn.BatchNorm2d(24),
            nn.ReLU()
        )

        self.hidden = nn.Sequential(
            nn.Conv2d(24, 32, kernel_size=2, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        # dueling heads: state value V(s) and action advantages A(s, a)
        self.out_value = nn.Sequential(
            nn.Linear(384, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

        self.out_advantage = nn.Sequential(
            nn.Linear(384, 128),
            nn.ReLU(),
            nn.Linear(128, 4)
        )

    def forward(self, x):
        x = self.input(x)
        x = self.hidden(x)
        x = torch.flatten(x, start_dim=1)
        x_v = self.out_value(x)
        x_a = self.out_advantage(x)
        # subtract the per-sample mean advantage (not the mean over the whole batch)
        x = x_v + (x_a - x_a.mean(dim=1, keepdim=True))
        return x
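
To double-check the dimensions: the 384 in the linear heads comes from applying the two conv layers (kernel 2, stride 2, padding 1) to the 7 x 11 board. A quick sanity check:

    # per dimension: out = floor((in + 2*padding - kernel) / stride) + 1
    # conv1: 7 x 11 -> 4 x 6 (24 channels), conv2: 4 x 6 -> 3 x 4 (32 channels)
    # flattened: 32 * 3 * 4 = 384, which matches the Linear(384, ...) heads
    net = Q_network()
    print(net(torch.zeros(1, 10, 7, 11)).shape)  # torch.Size([1, 4]): one Q-value per action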

class DQN(nn.Module):

    def __init__(self):
        super(DQN, self).__init__()
        self.q_network = Q_network()
        self.loss_fn = None
        self.optimizer = None

    def forward(self, x):
        x = self.q_network(x)
        return x

def state_input_to_features(state):
    # 5 channels of 77 cells each (7 x 11 board): geese 0-3 and food
    layers = np.zeros((5, 77))
    for i, goose in enumerate(state.geese):
        if goose:
            # head is marked with 1, body segments with 0.5
            layers[i][goose[0]] = 1
            if len(goose) > 1:
                for pos in goose[1:]:
                    layers[i][pos] = 0.5
    for food in state.food:
        layers[4][food] = 1
    layers = layers.reshape(5, 7, 11)
    return layers

def select_action(network, state, epsilon):
    # epsilon-greedy: random action with probability epsilon, otherwise greedy
    if np.random.random() < epsilon:
        return np.random.randint(0, 4)
    else:
        with torch.no_grad():
            network.eval()
            tmp = network(state)
            tmp = tmp.max(1)
            return tmp[1].view(-1, 1)

def training_step(network, target_network, replay_buffer):
    network.train()
    x = replay_buffer.get_item(network.batch_size)
    network.optimizer.zero_grad()
    s = torch.stack([tmp['state'] for tmp in x])
    a = torch.stack([tmp['action'] for tmp in x])
    s_new = torch.stack([tmp['new_state'] for tmp in x])
    r = torch.stack([tmp['reward'] for tmp in x])
    done = [tmp['done'] for tmp in x]
    output = network(s.view(network.batch_size, 10, 7, 11))
    predicted_values = output.gather(1, a)
    with torch.no_grad():
        # double-DQN target: the online network selects the action,
        # the target network evaluates it
        s_new = s_new.view(network.batch_size, 10, 7, 11)
        output_target = target_network(s_new).gather(1, select_action(network, s_new, 0))
        expected_values = output_target.view(network.batch_size, 1) * 0.99 + r
        for i, d in enumerate(done):
            # no bootstrapped value for terminal states
            if d:
                expected_values[i] = r[i]
    loss = network.loss_fn(predicted_values, expected_values)
    loss.backward()
    for param in network.parameters():
        param.grad.data.clamp_(-1, 1)
    network.optimizer.step()
    return loss.item()
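
For clarity, the target that training_step is meant to compute is the double-DQN target (action chosen by the online network, evaluated by the target network), with no bootstrapping on terminal transitions:

    y = r + 0.99 * Q_target(s', argmax_a Q_online(s', a))    for non-terminal transitions
    y = r                                                     for terminal transitions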

def get_reward(last_observation, current_observation):
    if not current_observation.geese[0]:
        # my goose is gone -> it died
        return -100
    elif not last_observation:
        # first step of an episode: length > 1 means food was already eaten
        if len(current_observation.geese[0]) > 1:
            return 10
        else:
            return 1
    elif len(current_observation.geese[0]) > len(last_observation.geese[0]):
        # goose grew -> it ate food
        return 10
    else:
        # survived another step
        return 1

# action index -> environment action name
# (this mapping was missing from my original paste; the order is arbitrary but fixed)
action_dict = {i: a.name for i, a in enumerate(Action)}

def train_agent(path, epochs):
    env = make('hungry_geese', debug=True)
    training = env.train([None, 'greedy', 'greedy', 'greedy'])
    obs = training.reset()
    replay_buffer = Replay_buffer(100000)
    if path:
        network = load_network(path)
        target_network = load_network(path)
    else:
        network = DQN()
        target_network = DQN()
    network.loss_fn = torch.nn.SmoothL1Loss()
    network.optimizer = torch.optim.RMSprop(network.parameters(), lr=1e-2)
    network.batch_size = 32
    state = env.reset(num_agents=4)[0].observation
    losses = list()
    episode_returns = list()
    episode_return = 0
    last_observation = None
    last_position = torch.zeros(5, 7, 11)
    for iteration in range(epochs):
        env.render()
        current_state = torch.Tensor(state_input_to_features(state))
        # stack last and current 5-channel frames into the 10-channel input
        s = torch.cat((last_position, current_state), dim=0)
        action = select_action(network, s.view(10, 7, 11).unsqueeze(dim=0), 0.2)
        if type(action) != int:
            action = action.item()
        obs, reward, done, info = training.step(action_dict[action])
        reward = get_reward(last_observation, obs)
        last_observation = obs
        last_position = current_state
        s_2 = torch.cat((current_state, torch.Tensor(state_input_to_features(obs))), dim=0)
        replay_buffer.add({'state': s, 'action': action, 'new_state': s_2, 'reward': reward, 'done': done})
        if iteration >= network.batch_size:
            losses.append(training_step(network, target_network, replay_buffer))
        if iteration % 1000 == 0:
            # periodically sync the target network and plot progress
            target_network.load_state_dict(network.state_dict())
            plt.plot(range(len(losses)), losses, label='Loss')
            plt.legend()
            plt.show()
            plt.plot(range(len(episode_returns)), episode_returns, label='Episode return')
            plt.legend()
            plt.show()
        state = obs
        episode_return += reward
        if done:
            state = env.reset(num_agents=4)[0].observation
            episode_returns.append(episode_return)
            episode_return = 0
            last_observation = None
            last_position = torch.zeros(5, 7, 11)
    torch.save(network.state_dict(), 'soft_goose.pt')
    plt.plot(range(len(losses)), losses, label='Loss')
    plt.legend()
    plt.show()

    plt.plot(range(len(episode_returns)), episode_returns, label='Episode return')
    plt.legend()
    plt.show()

def load_network(path):
    network = DQN()
    network.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    network = network.eval()
    return network
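
For completeness, this is how I kick off training (path=None starts from scratch; the iteration count here is just an example, not tuned):

    train_agent(path=None, epochs=20000)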