I followed a tutorial for a DDQN to beat Pong. The Keras version beats it with a perfect score, but my PyTorch translation of the same agent doesn't learn at all. What am I missing? I've pasted the full code for each version below in case the difference is something subtle.
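Both versions are supposed to compute the same target: take the online network's predictions and overwrite the entry for the action actually taken with r + gamma * max_a Q_target(s', a), zeroing the bootstrap term on terminal transitions. A minimal numpy sketch of that update as I understand it from the tutorial (not code from either agent):

import numpy as np

def build_targets(q_eval, q_next, actions, rewards, dones, gamma=0.99):
    # Start from the online predictions so the loss is zero everywhere
    # except the action that was actually taken.
    q_target = q_eval.copy()
    idx = np.arange(len(actions))
    q_target[idx, actions] = rewards + gamma * q_next.max(axis=1) * (1 - dones)
    return q_target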
KERAS AGENT
import random
import gym
from agent_base import Agent
from agent_augments.memory import ReplayMemory
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D, Dense, Flatten
from keras.optimizers import Adam
class PongAgent(Agent, ReplayMemory):
def __init__(self,
state_space,
action_space,
channels=4,
batch_size=64,
epsilon=1,
epsilon_min=0.01,
epsilon_decay=1e-5,
gamma=0.99,
learning_rate=0.0002,
save_freq=500,
training=True):
ReplayMemory.__init__(self, capacity=10000)
self.state_space = state_space
self.action_space = action_space
self.channels = channels
self.batch_size = batch_size
self.save_freq = save_freq
self.iteration_count = 0
self.training = training
self.replace = 1000
self.epsilon = epsilon
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay
self.gamma = gamma
        # online network (input_shape only matters on the first layer)
        self.q_eval = Sequential([
            Conv2D(filters=32, kernel_size=3, strides=2, activation='relu',
                   input_shape=state_space, data_format='channels_first'),
            Conv2D(filters=64, kernel_size=5, strides=2, activation='relu',
                   data_format='channels_first'),
            Conv2D(filters=64, kernel_size=3, strides=1, activation='relu',
                   data_format='channels_first'),
            Flatten(),
            Dense(512, activation='relu'),
            Dense(action_space)])
        self.q_eval.compile(optimizer=Adam(lr=learning_rate), loss='mean_squared_error')
        # target network, same architecture
        self.q_next = Sequential([
            Conv2D(filters=32, kernel_size=3, strides=2, activation='relu',
                   input_shape=state_space, data_format='channels_first'),
            Conv2D(filters=64, kernel_size=5, strides=2, activation='relu',
                   data_format='channels_first'),
            Conv2D(filters=64, kernel_size=3, strides=1, activation='relu',
                   data_format='channels_first'),
            Flatten(),
            Dense(512, activation='relu'),
            Dense(action_space)])
        self.q_next.compile(optimizer=Adam(lr=learning_rate), loss='mean_squared_error')
def replace_target_network(self):
if self.iteration_count % self.replace == 0:
self.q_next.set_weights(self.q_eval.get_weights())
def decay_epsilon(self):
self.epsilon = self.epsilon - self.epsilon_decay \
if self.epsilon > self.epsilon_min else self.epsilon_min
def act(self, state) -> int:
if random.random() > self.epsilon:
return np.argmax(self.q_eval.predict(np.expand_dims(state, 0))).item()
else:
return random.randrange(self.action_space)
def learn(self, *args, **kwargs):
self.memorise(kwargs.get('state'),
kwargs.get('action'),
kwargs.get('reward'),
kwargs.get('next_state'),
kwargs.get('done'))
if not self.training:
return
if self.get_mem_count() >= self.batch_size:
sample = self.sample(self.batch_size)
states = np.stack([i[0] for i in sample])
actions = np.array([i[1] for i in sample])
rewards = np.array([i[2] for i in sample])
next_states = np.stack([i[3] for i in sample])
dones = np.array([i[4] for i in sample])
self.replace_target_network()
q_eval = self.q_eval.predict(states)
q_next = self.q_next.predict(next_states)
            q_target = q_eval.copy()  # note: q_eval[:] on a numpy array is a view, not a copy
indices = np.arange(self.batch_size)
q_target[indices, actions] = rewards + \
self.gamma*np.max(q_next, axis=1)*(1 - dones)
self.q_eval.train_on_batch(states, q_target)
self.decay_epsilon()
self.iteration_count += 1
def memorise(self, state, action, reward, next_state, done):
self.store(
state,
action,
reward,
next_state,
done)
def get_epsilon(self):
return self.epsilon
class SkipEnv(gym.Wrapper):
def __init__(self, env=None, skip=4):
super(SkipEnv, self).__init__(env)
self._skip = skip
def step(self, action):
t_reward = 0.0
done = False
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
t_reward += reward
if done:
break
return obs, t_reward, done, info
def reset(self):
self._obs_buffer = []
obs = self.env.reset()
self._obs_buffer.append(obs)
return obs
class PreProcessFrame(gym.ObservationWrapper):
def __init__(self, env=None):
super(PreProcessFrame, self).__init__(env)
self.observation_space = gym.spaces.Box(low=0, high=255,
shape=(80, 80, 1), dtype=np.uint8)
def observation(self, obs):
return PreProcessFrame.process(obs)
@staticmethod
def process(frame):
        new_frame = frame.astype(np.float32)
new_frame = 0.299 * new_frame[:, :, 0] + 0.587 * new_frame[:, :, 1] + \
0.114 * new_frame[:, :, 2]
new_frame = new_frame[35:195:2, ::2].reshape(80, 80, 1)
return new_frame.astype(np.uint8)
class MoveImgChannel(gym.ObservationWrapper):
def __init__(self, env):
super(MoveImgChannel, self).__init__(env)
self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
shape=(self.observation_space.shape[-1],
self.observation_space.shape[0],
self.observation_space.shape[1]),
dtype=np.float32)
def observation(self, observation):
return np.moveaxis(observation, 2, 0)
class ScaleFrame(gym.ObservationWrapper):
def observation(self, obs):
return np.array(obs).astype(np.float32) / 255.0
class BufferWrapper(gym.ObservationWrapper):
def __init__(self, env, n_steps):
super(BufferWrapper, self).__init__(env)
self.observation_space = gym.spaces.Box(
env.observation_space.low.repeat(n_steps, axis=0),
env.observation_space.high.repeat(n_steps, axis=0),
dtype=np.float32)
def reset(self):
self.buffer = np.zeros_like(self.observation_space.low, dtype=np.float32)
return self.observation(self.env.reset())
def observation(self, observation):
self.buffer[:-1] = self.buffer[1:]
self.buffer[-1] = observation
return self.buffer
def make_env(env_name):
env = gym.make(env_name)
env = SkipEnv(env)
env = PreProcessFrame(env)
env = MoveImgChannel(env)
env = BufferWrapper(env, 4)
return ScaleFrame(env)
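For context, both agents are driven by the same simple loop, roughly like this (simplified, since agent_base.Agent isn't shown here; the env id is just an example of the Pong variant I use):

env = make_env('PongNoFrameskip-v4')  # example env id
agent = PongAgent(env.observation_space.shape, env.action_space.n)
for episode in range(1000):
    state = env.reset()
    done = False
    while not done:
        action = agent.act(state)
        next_state, reward, done, info = env.step(action)
        agent.learn(state=state, action=action, reward=reward,
                    next_state=next_state, done=done)
        state = next_state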
PYTORCH AGENT
import random
import gym
import torch
from torch.nn.modules import Module, Linear, Conv2d
import torch.nn.functional as fn
import numpy as np
from agent_base import Agent
from agent_augments.memory import ReplayMemory
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
print("Using device: {}".format(device))
class DQN(Module):
def __init__(self,
action_space,
transform,
channels=1,
learning_rate=0.001):
super(DQN, self).__init__()
self.in_layer = Conv2d(channels, 32, 3, 2)
self.hidden_conv_1 = Conv2d(32, 64, 5, 2)
self.hidden_conv_2 = Conv2d(64, 64, 3, 1)
        self.hidden_fc1 = Linear(64 * 16 * 16, 512)  # conv output for an 80x80 input: 80 -> 39 -> 18 -> 16
self.output = Linear(512, action_space)
self.loss = torch.nn.MSELoss()
self.transform = transform
self.optimizer = torch.optim.Adam(
self.parameters(), lr=learning_rate)
self.to(device)
self.train()
def forward(self, state):
in_out = fn.relu(self.in_layer(state))
in_out = fn.relu(self.hidden_conv_1(in_out))
in_out = fn.relu(self.hidden_conv_2(in_out))
in_out = in_out.view(-1, 64 * 16 * 16)
in_out = fn.relu(self.hidden_fc1(in_out))
return self.output(in_out)
class PongAgent(Agent, Module, ReplayMemory):
def __init__(self,
state_space,
action_space,
channels=4,
batch_size=64,
epsilon=1,
epsilon_min=0.01,
epsilon_decay=1e-5,
gamma=0.99,
learning_rate=0.0002,
save_freq=500,
training=True):
Module.__init__(self)
ReplayMemory.__init__(self, capacity=10000)
self.state_space = state_space
self.action_space = action_space
self.channels = channels
self.batch_size = batch_size
self.save_freq = save_freq
self.iteration_count = 0
        self.training = training  # note: this shadows nn.Module's built-in training flag
self.epsilon = epsilon
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay
self.gamma = gamma
self.replace_every = 1000
self.q_net = DQN(action_space, None, channels, learning_rate)
self.target_net = DQN(action_space, None, channels, learning_rate)
def update_target_network(self):
if self.iteration_count % self.replace_every == 0:
self.target_net.load_state_dict(self.q_net.state_dict())
def decay_epsilon(self):
self.epsilon = self.epsilon - self.epsilon_decay \
if self.epsilon > self.epsilon_min else self.epsilon_min
    def act(self, state) -> int:
        if random.random() > self.epsilon:
            state = torch.tensor(state, device=device)
            with torch.no_grad():  # no graph needed for action selection
                return self.q_net(state.unsqueeze(0)).argmax(1).item()
        else:
            return random.randrange(self.action_space)
def learn(self, *args, **kwargs):
self.memorise(kwargs.get('state'),
kwargs.get('action'),
kwargs.get('reward'),
kwargs.get('next_state'),
kwargs.get('done'))
if not self.training:
return
if self.get_mem_count() >= self.batch_size:
sample = self.sample(self.batch_size)
            states = torch.stack([i[0] for i in sample])
            actions = torch.tensor([i[1] for i in sample], device=device)
            rewards = torch.tensor([i[2] for i in sample], dtype=torch.float32, device=device)
            next_states = torch.stack([i[3] for i in sample])
            # dones must be bool: on a uint8 tensor, ~ is bitwise NOT (~0 = 255),
            # which silently corrupts the (1 - done) mask below
            dones = torch.tensor([i[4] for i in sample], dtype=torch.bool, device=device)
            self.update_target_network()
            current_q_vals = self.q_net(states)
            with torch.no_grad():  # no gradients should flow through the target network
                next_q_vals = self.target_net(next_states)
            q_target = current_q_vals.clone().detach()
            q_target[torch.arange(states.size(0), device=device), actions] = rewards + \
                self.gamma * next_q_vals.max(dim=1)[0] * (~dones).float()
            self.q_net.optimizer.zero_grad()
            loss = self.q_net.loss(current_q_vals, q_target)
            loss.backward()
            self.q_net.optimizer.step()
self.decay_epsilon()
self.iteration_count += 1
def memorise(self, state, action, reward, next_state, done):
        # convert to tensors up front so learn() can torch.stack() samples directly
        state = torch.tensor(state, device=device)
        next_state = torch.tensor(next_state, device=device)
        self.store(state, action, reward, next_state, done)
def flatten(self, x):
flattened_count = 1
for dim in x.shape[1:]:
flattened_count *= dim
return x.view(-1, flattened_count)
def get_epsilon(self):
return self.epsilon
def save_model(self):
torch.save(self.q_net.state_dict(), "D:/models/q_net-agent-q-episode-{}".format(self.iteration_count))
torch.save(self.target_net.state_dict(), "D:/models/target_net-agent_q-episode-{}".format(self.iteration_count))
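One subtle difference I did track down while writing this up: ~ on a torch.uint8 tensor is bitwise NOT (~0 = 255, ~1 = 254), not logical negation, so a uint8 dones tensor silently blows the targets up by a factor of roughly 255; it only behaves like the Keras (1 - dones) mask on a bool tensor, which is why learn() above uses dtype=torch.bool. A quick snippet that shows the difference:

import torch

dones = torch.tensor([0, 1], dtype=torch.uint8)
print(~dones)            # tensor([255, 254], dtype=torch.uint8) - bitwise NOT
dones = torch.tensor([0, 1], dtype=torch.bool)
print((~dones).float())  # tensor([1., 0.]) - matches the Keras (1 - dones) mask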
(SkipEnv, PreProcessFrame, MoveImgChannel, ScaleFrame, BufferWrapper and make_env are identical to the Keras version above, so I haven't repeated them here.)
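Finally, a quick sanity check that the Linear(64 * 16 * 16, 512) input size matches the conv stack for an 80x80 frame (the view() in forward() would raise if it didn't; 6 is Pong's action count):

import torch

net = DQN(action_space=6, transform=None, channels=4)
dummy = torch.zeros(1, 4, 80, 80, device=device)
with torch.no_grad():
    print(net(dummy).shape)  # torch.Size([1, 6])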