Pytorch and keras ddqn seem identical, only keras learns

I followed a tutorial for ddqn to beat pong, it beats it with a perfect score in keras, but trying to translate it to pytorch it doesn’t learn at all. What am I missing? I pasted all the code for each model in case I miss something subtle.

KERAS AGENT

import random

import gym

from agent_base import Agent
from agent_augments.memory import ReplayMemory

import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D, Dense, Flatten
from keras.optimizers import Adam


class PongAgent(Agent, ReplayMemory):

    def __init__(self,
                 state_space,
                 action_space,
                 channels=4,
                 batch_size=64,
                 epsilon=1,
                 epsilon_min=0.01,
                 epsilon_decay=1e-5,
                 gamma=0.99,
                 learning_rate=0.0002,
                 save_freq=500,
                 training=True):
        ReplayMemory.__init__(self, capacity=10000)

        self.state_space = state_space
        self.action_space = action_space
        self.channels = channels
        self.batch_size = batch_size
        self.save_freq = save_freq
        self.iteration_count = 0
        self.training = training
        self.replace = 1000

        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.gamma = gamma

        self.q_eval = Sequential([
            Conv2D(filters=32, kernel_size=3, strides=2, activation='relu',
                   input_shape=state_space, data_format='channels_first'),
            Conv2D(filters=64, kernel_size=5, strides=2, activation='relu',
                   input_shape=state_space, data_format='channels_first'),
            Conv2D(filters=64, kernel_size=3, strides=1, activation='relu',
                   input_shape=state_space, data_format='channels_first'),
            Flatten(),
            Dense(512, activation='relu'),
            Dense(action_space)]
        )
        self.q_eval.compile(optimizer=Adam(lr=learning_rate), loss='mean_squared_error')

        self.q_next = Sequential([
            Conv2D(filters=32, kernel_size=3, strides=2, activation='relu',
                   input_shape=state_space, data_format='channels_first'),
            Conv2D(filters=64, kernel_size=5, strides=2, activation='relu',
                   input_shape=state_space, data_format='channels_first'),
            Conv2D(filters=64, kernel_size=3, strides=1, activation='relu',
                   input_shape=state_space, data_format='channels_first'),
            Flatten(),
            Dense(512, activation='relu'),
            Dense(action_space)]
        )
        self.q_next.compile(optimizer=Adam(lr=learning_rate), loss='mean_squared_error')

    def replace_target_network(self):
        if self.iteration_count % self.replace == 0:
            self.q_next.set_weights(self.q_eval.get_weights())

    def decay_epsilon(self):
        self.epsilon = self.epsilon - self.epsilon_decay \
            if self.epsilon > self.epsilon_min else self.epsilon_min

    def act(self, state) -> int:
        if random.random() > self.epsilon:
            return np.argmax(self.q_eval.predict(np.expand_dims(state, 0))).item()
        else:
            return random.randrange(self.action_space)

    def learn(self, *args, **kwargs):
        self.memorise(kwargs.get('state'),
                      kwargs.get('action'),
                      kwargs.get('reward'),
                      kwargs.get('next_state'),
                      kwargs.get('done'))

        if not self.training:
            return

        if self.get_mem_count() >= self.batch_size:
            sample = self.sample(self.batch_size)
            states = np.stack([i[0] for i in sample])
            actions = np.array([i[1] for i in sample])
            rewards = np.array([i[2] for i in sample])
            next_states = np.stack([i[3] for i in sample])
            dones = np.array([i[4] for i in sample])

            self.replace_target_network()

            q_eval = self.q_eval.predict(states)

            q_next = self.q_next.predict(next_states)

            q_target = q_eval[:]
            indices = np.arange(self.batch_size)
            q_target[indices, actions] = rewards + \
                                    self.gamma*np.max(q_next, axis=1)*(1 - dones)
            self.q_eval.train_on_batch(states, q_target)

            self.decay_epsilon()
            self.iteration_count += 1

    def memorise(self, state, action, reward, next_state, done):
        self.store(
            state,
            action,
            reward,
            next_state,
            done)

    def get_epsilon(self):
        return self.epsilon

class SkipEnv(gym.Wrapper):
    def __init__(self, env=None, skip=4):
        super(SkipEnv, self).__init__(env)
        self._skip = skip

    def step(self, action):
        t_reward = 0.0
        done = False
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            t_reward += reward
            if done:
                break
        return obs, t_reward, done, info

    def reset(self):
        self._obs_buffer = []
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs


class PreProcessFrame(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(PreProcessFrame, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255,
                                                shape=(80, 80, 1), dtype=np.uint8)

    def observation(self, obs):
        return PreProcessFrame.process(obs)

    @staticmethod
    def process(frame):
        new_frame = np.reshape(frame, frame.shape).astype(np.float32)

        new_frame = 0.299 * new_frame[:, :, 0] + 0.587 * new_frame[:, :, 1] + \
                    0.114 * new_frame[:, :, 2]

        new_frame = new_frame[35:195:2, ::2].reshape(80, 80, 1)

        return new_frame.astype(np.uint8)


class MoveImgChannel(gym.ObservationWrapper):
    def __init__(self, env):
        super(MoveImgChannel, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
                                                shape=(self.observation_space.shape[-1],
                                                       self.observation_space.shape[0],
                                                       self.observation_space.shape[1]),
                                                dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)


class ScaleFrame(gym.ObservationWrapper):
    def observation(self, obs):
        return np.array(obs).astype(np.float32) / 255.0


class BufferWrapper(gym.ObservationWrapper):
    def __init__(self, env, n_steps):
        super(BufferWrapper, self).__init__(env)
        self.observation_space = gym.spaces.Box(
            env.observation_space.low.repeat(n_steps, axis=0),
            env.observation_space.high.repeat(n_steps, axis=0),
            dtype=np.float32)

    def reset(self):
        self.buffer = np.zeros_like(self.observation_space.low, dtype=np.float32)
        return self.observation(self.env.reset())

    def observation(self, observation):
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer


def make_env(env_name):
    env = gym.make(env_name)
    env = SkipEnv(env)
    env = PreProcessFrame(env)
    env = MoveImgChannel(env)
    env = BufferWrapper(env, 4)
    return ScaleFrame(env)

PYTORCH AGENT

import random

import gym
import torch
from torch.nn.modules import Module, Linear, Conv2d
import torch.nn.functional as fn
import numpy as np
from agent_base import Agent
from agent_augments.memory import ReplayMemory

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print("Using device: {}".format(device))


class DQN(Module):

    def __init__(self,
                 action_space,
                 transform,
                 channels=1,
                 learning_rate=0.001):
        super(DQN, self).__init__()

        self.in_layer = Conv2d(channels, 32, 3, 2)
        self.hidden_conv_1 = Conv2d(32, 64, 5, 2)
        self.hidden_conv_2 = Conv2d(64, 64, 3, 1)
        self.hidden_fc1 = Linear(64 * 16 * 16, 512)
        self.output = Linear(512, action_space)
        self.loss = torch.nn.MSELoss()

        self.transform = transform

        self.optimizer = torch.optim.Adam(
            self.parameters(), lr=learning_rate)

        self.to(device)

        self.train()

    def forward(self, state):
        in_out = fn.relu(self.in_layer(state))
        in_out = fn.relu(self.hidden_conv_1(in_out))
        in_out = fn.relu(self.hidden_conv_2(in_out))
        in_out = in_out.view(-1, 64 * 16 * 16)
        in_out = fn.relu(self.hidden_fc1(in_out))
        return self.output(in_out)

class PongAgent(Agent, Module, ReplayMemory):

    def __init__(self,
                 state_space,
                 action_space,
                 channels=4,
                 batch_size=64,
                 epsilon=1,
                 epsilon_min=0.01,
                 epsilon_decay=1e-5,
                 gamma=0.99,
                 learning_rate=0.0002,
                 save_freq=500,
                 training=True):
        Module.__init__(self)
        ReplayMemory.__init__(self, capacity=10000)

        self.state_space = state_space
        self.action_space = action_space
        self.channels = channels
        self.batch_size = batch_size
        self.save_freq = save_freq
        self.iteration_count = 0
        self.training = training

        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.gamma = gamma
        self.replace_every = 1000

        self.q_net = DQN(action_space, None, channels, learning_rate)
        self.target_net = DQN(action_space, None, channels, learning_rate)

    def update_target_network(self):
        if self.iteration_count % self.replace_every == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())

    def decay_epsilon(self):
        self.epsilon = self.epsilon - self.epsilon_decay \
            if self.epsilon > self.epsilon_min else self.epsilon_min

    def act(self, state) -> int:
        if random.random() > self.epsilon:
            state = torch.tensor(state, device=device)
            return self.q_net(state.unsqueeze(0)).argmax(1).item()
        else:
            return random.randrange(self.action_space)

    def learn(self, *args, **kwargs):
        self.memorise(kwargs.get('state'),
                      kwargs.get('action'),
                      kwargs.get('reward'),
                      kwargs.get('next_state'),
                      kwargs.get('done'))

        if not self.training:
            return

        if self.get_mem_count() >= self.batch_size:
            sample = self.sample(self.batch_size)
            states = torch.stack([i[0] for i in sample])
            actions = torch.tensor([i[1] for i in sample], device=device)
            rewards = torch.tensor([i[2] for i in sample], dtype=torch.float32, device=device)
            next_states = torch.stack([i[3] for i in sample])
            dones = torch.tensor([i[4] for i in sample], dtype=torch.uint8, device=device)

            self.update_target_network()

            current_q_vals = self.q_net(states)
            next_q_vals = self.target_net(next_states)
            q_target = current_q_vals.clone().detach()
            q_target[torch.arange(states.size()[0]), actions] = rewards + (self.gamma * next_q_vals.max(dim=1)[0]) * (
                ~dones).float()

            self.q_net.optimizer.zero_grad()
            loss = self.q_net.loss(current_q_vals, q_target)
            loss.backward()
            self.q_net.optimizer.step()

            self.decay_epsilon()
            self.iteration_count += 1

    def memorise(self, state, action, reward, next_state, done):
        # state = self.transform(state).to(device)
        # next_state = self.transform(next_state).to(device)
        state = torch.tensor(state, device=device)
        next_state = torch.tensor(next_state, device=device)

        self.store(
            state,
            action,
            reward,
            next_state,
            done)

    def flatten(self, x):
        flattened_count = 1
        for dim in x.shape[1:]:
            flattened_count *= dim
        return x.view(-1, flattened_count)

    def get_epsilon(self):
        return self.epsilon

    def save_model(self):
        torch.save(self.q_net.state_dict(), "D:/models/q_net-agent-q-episode-{}".format(self.iteration_count))
        torch.save(self.target_net.state_dict(), "D:/models/target_net-agent_q-episode-{}".format(self.iteration_count))


class SkipEnv(gym.Wrapper):
    def __init__(self, env=None, skip=4):
        super(SkipEnv, self).__init__(env)
        self._skip = skip

    def step(self, action):
        t_reward = 0.0
        done = False
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            t_reward += reward
            if done:
                break
        return obs, t_reward, done, info

    def reset(self):
        self._obs_buffer = []
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs

class PreProcessFrame(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(PreProcessFrame, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255,
                                                shape=(80,80,1), dtype=np.uint8)
    def observation(self, obs):
        return PreProcessFrame.process(obs)

    @staticmethod
    def process(frame):

        new_frame = np.reshape(frame, frame.shape).astype(np.float32)

        new_frame = 0.299*new_frame[:,:,0] + 0.587*new_frame[:,:,1] + \
                    0.114*new_frame[:,:,2]

        new_frame = new_frame[35:195:2, ::2].reshape(80,80,1)

        return new_frame.astype(np.uint8)

class MoveImgChannel(gym.ObservationWrapper):
    def __init__(self, env):
        super(MoveImgChannel, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
                            shape=(self.observation_space.shape[-1],
                                   self.observation_space.shape[0],
                                   self.observation_space.shape[1]),
                            dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)

class ScaleFrame(gym.ObservationWrapper):
    def observation(self, obs):
        return np.array(obs).astype(np.float32) / 255.0

class BufferWrapper(gym.ObservationWrapper):
    def __init__(self, env, n_steps):
        super(BufferWrapper, self).__init__(env)
        self.observation_space = gym.spaces.Box(
                             env.observation_space.low.repeat(n_steps, axis=0),
                             env.observation_space.high.repeat(n_steps, axis=0),
                             dtype=np.float32)

    def reset(self):
        self.buffer = np.zeros_like(self.observation_space.low, dtype=np.float32)
        return self.observation(self.env.reset())

    def observation(self, observation):
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer

def make_env(env_name):
    env = gym.make(env_name)
    env = SkipEnv(env)
    env = PreProcessFrame(env)
    env = MoveImgChannel(env)
    env = BufferWrapper(env, 4)
    return ScaleFrame(env)

I figured out what was going wrong, I had been using the tilde operator before to invert uint8 tensors, but recently I had updated to the latest version of pytorch that seems to have changed how the operator works. It was changing the done values to 255