REINFORCE algorithm fails to learn

Hi,

I tried to use a CNN with the REINFORCE algorithm to train an agent to play Atari Pong, but the model does not learn anything even after many episodes; the episode reward stays at -21. Are there any bugs in my program?

Here is the whole code:

import os
import cv2
import copy
import random
import gc
import gym
import numpy as np
import torch
import ipywidgets as widgets
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt

from gym import spaces
from tqdm import tqdm
from collections import deque
from IPython import display
from IPython.display import clear_output
from matplotlib import animation

cv2.ocl.setUseOpenCL(False)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

gif_path = './gif/'
model_path = './model/'
result_path = './result/'
reward_path = './reward/'

for path in (gif_path, model_path, result_path, reward_path):
    os.makedirs(path, exist_ok=True)
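
# DeepMind-style Atari preprocessing wrappers (noop reset, fire reset, episodic
# life, frame skip with max-pooling, reward clipping, 84x84 grayscale frames,
# optional frame stacking), adapted to the 5-tuple Gym step API.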
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
    
    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        
        assert noops > 0
        
        obs = None
        
        for _ in range(noops):
            obs, _, done, _, info = self.env.step(self.noop_action)

            if done:
                # reset() returns (obs, info) under the newer Gym API
                obs, info = self.env.reset(**kwargs)

        return obs, info
    
    def step(self, ac):
        return self.env.step(ac)

class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3
    
    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, done, _, info = self.env.step(1)
        
        if done:
            self.env.reset(**kwargs)
        
        obs, _, done, _, info = self.env.step(2)
        
        if done:
            self.env.reset(**kwargs)
        
        return obs, info
    
    def step(self, ac):
        return self.env.step(ac)

class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.lives = 0
        self.was_real_done = True
    
    def step(self, action):
        obs, reward, done, truncated, info = self.env.step(action)
        self.was_real_done = done
        lives = self.env.unwrapped.ale.lives()
        
        if lives < self.lives and lives > 0:
            done = True
        
        self.lives = lives
        
        return obs, reward, done, truncated, info
    
    def reset(self, **kwargs):
        if self.was_real_done:
            obs, info = self.env.reset(**kwargs)
        else:
            obs, _, _, _, info = self.env.step(0)
        
        self.lives = self.env.unwrapped.ale.lives()
        
        return obs, info

class MaxandSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        super().__init__(env)
        self._obs_buffer = np.zeros((2, ) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip
    
    def step(self, action):
        total_reward = 0.0
        done = None
        
        for i in range(self._skip):
            obs, reward, done, truncated, info = self.env.step(action)
            
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            
            total_reward += reward
            
            if done:
                break
        
        max_frame = self._obs_buffer.max(axis=0)
        
        return max_frame, total_reward, done, truncated, info
    
    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

class ClipRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        super().__init__(env)
    
    def reward(self, reward):
        return np.sign(reward)

class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8)
    
    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        
        return frame[:, :, None]

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        super().__init__(env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)
    
    def reset(self):
        ob, info = self.env.reset()
        
        for _ in range(self.k):
            self.frames.append(ob)
        
        return self._get_ob(), info
    
    def step(self, action):
        # the newer Gym API returns (obs, reward, done, truncated, info)
        ob, reward, done, truncated, info = self.env.step(action)
        self.frames.append(ob)
        
        return self._get_ob(), reward, done, truncated, info
        
    def _get_ob(self):
        assert len(self.frames) == self.k
        
        return LazyFrames(list(self.frames))

class ScaledFloatFrame(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
    
    def observation(self, observation):
        return np.array(observation).astype(np.float32) / 255.0

class LazyFrames(object):
    def __init__(self, frames):
        self._frames = frames
        self._out = None
    
    def _force(self):
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        
        return self._out
    
    def __array__(self, dtype=None):
        out = self._force()
        
        if dtype is not None:
            out = out.astype(dtype)
        
        return out
    
    def __len__(self):
        return len(self._force())
    
    def __getitem__(self, i):
        return self._force()[i]

def make_atari(env_id, render_mode=None):
    env = gym.make(env_id, render_mode=render_mode)     
    assert 'NoFrameskip' in env.spec.id
    env = NoopResetEnv(env, noop_max=30)
    env = MaxandSkipEnv(env, skip=4)
    
    return env

def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
    if episode_life:
        env = EpisodicLifeEnv(env)
    
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    
    env = WarpFrame(env)
    
    if scale:
        env = ScaledFloatFrame(env)
        
    if clip_rewards:
        env = ClipRewardEnv(env)
    
    if frame_stack:
        env = FrameStack(env, 4)
    
    return env

class ImagetoPyTorch(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        old_shape = self.observation_space.shape
        # channels-first: (H, W, C) -> (C, H, W); observations stay uint8 in [0, 255]
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.uint8)
    
    def observation(self, observation):
        return np.transpose(observation, (2, 0, 1))

def wrap_pytorch(env):
    return ImagetoPyTorch(env)
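
# The environment is built as: NoFrameskip Pong -> noop reset + 4-frame skip/max
# -> DeepMind preprocessing (84x84 grayscale, clipped rewards; frame_stack is
# left at its default of False) -> channel-first observations for PyTorch.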

env_id = 'PongNoFrameskip-v4'
env = make_atari(env_id, render_mode='rgb_array')
env = wrap_deepmind(env)
env = wrap_pytorch(env)
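
# Policy network: the Nature-DQN convolutional stack followed by a 512-unit
# fully connected layer and a softmax over the actions. The class also stores
# the episode trajectory and runs the REINFORCE update at the end of an episode.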

class REINFORCE(nn.Module):
    def __init__(self, input_shape, num_actions):
        super().__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        
        self.features = nn.Sequential(
                                    nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
                                    nn.ReLU(),
                                    nn.Conv2d(32, 64, kernel_size=4, stride=2),
                                    nn.ReLU(),
                                    nn.Conv2d(64, 64, kernel_size=3, stride=1),
                                    nn.ReLU()
                                    )
        
        self.fc = nn.Sequential(
                                nn.Linear(self.feature_size(), 512),
                                nn.ReLU(),
                                nn.Linear(512, self.num_actions)
                                )
        
        self.states = []
        self.actions = []
        self.rewards = []
    
    def remember(self, state, action, reward):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        
    def clear(self):
        self.states = []
        self.actions = []
        self.rewards = []
        
    def forward(self, x):
        # stack the incoming states/LazyFrames into a single float tensor on the right device
        x = torch.as_tensor(np.array(x), device=device, dtype=torch.float32)
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        probs = F.softmax(x, dim=-1).squeeze(0)
      
        return probs
    
    def feature_size(self):
        return self.features(torch.zeros(1, *self.input_shape)).view(1, -1).size(1)
    
    def select_action(self, state):
        probs = self(state)
        # sample an action from the softmax policy (move probs to CPU for numpy)
        action = np.random.choice(self.num_actions, p=probs.detach().cpu().numpy())
        
        return action
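
    # REINFORCE update on the stored episode:
    #   loss = -mean_t[ log pi(a_t | s_t) * G_t ]
    # where G_t is the normalized discounted reward-to-go from step t.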
    
    def train_policy(self, device=device):
        actions = torch.tensor(self.actions).to(device)
        rewards_to_go = self.rewards_to_go()
        rewards_to_go = torch.tensor(rewards_to_go, dtype=torch.float32).to(device)
        rewards_to_go = self.normalize_rewards(rewards_to_go)
        
        # re-run the stored episode through the network to get action probabilities
        probs = self(self.states)
        
        log_probs = torch.log(probs) 
        log_probs_for_actions = log_probs[range(len(actions)), actions]

        loss = -torch.mean(log_probs_for_actions * rewards_to_go)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        self.clear()
        
        return loss.item()
    
    def rewards_to_go(self, gamma=0.99):
        T = len(self.rewards)
        
        rewards_to_go = [0] * T
        rewards_to_go[T-1] = self.rewards[T-1]
        
        for i in range(T-2, -1, -1):
            rewards_to_go[i] = self.rewards[i] + gamma * rewards_to_go[i+1]
        
        return rewards_to_go
    
    def normalize_rewards(self, rewards):
        normalized_rewards = (rewards - rewards.mean()) / rewards.std()
        
        return normalized_rewards

model = REINFORCE(env.observation_space.shape, env.action_space.n)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

def plot(episode, rewards, losses, game):
    clear_output(True)
    plt.figure(figsize=(20, 5))
    plt.subplot(131)
    plt.title('episode %s. reward: %s' % (episode, np.mean(rewards[-10:])))
    plt.plot(rewards)
    plt.subplot(132)
    plt.title('loss')
    plt.plot(losses)
    plt.subplot(133)
    plt.title(f'ep: {episode}')
    plt.imshow(game)
    plt.axis('off')
    plt.savefig(f'{result_path}{env.spec.id}-{model.__class__.__name__}.png')
    plt.show()
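
# Training loop: run one full episode with the current policy, store
# (state, action, reward) at every step, then do a single policy-gradient
# update on the whole episode. Progress is plotted every 1,000 episodes.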

def train():
    episodes = 1_000_000
    losses = []
    total_rewards = []
    
    for episode in tqdm(range(1, episodes + 1)):
        state, _ = env.reset()
        done = False
        episode_reward = 0
        
        while not done:
            action = model.select_action([state])
            next_state, reward, done, _, _ = env.step(action)
            model.remember(state, action, reward)

            state = next_state
            episode_reward += reward
        
        # one policy-gradient update per finished episode
        loss = model.train_policy()
        losses.append(loss)
        total_rewards.append(episode_reward)
        
        if episode % 1_000 == 0:
            rgb_array = env.render()
            plot(episode, total_rewards, losses, rgb_array)
train()
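
For reference, a standalone check of the discounted-return recursion used in rewards_to_go() on a made-up reward list (gamma = 0.99) should give approximately [0.9801, 0.99, 1.0]:

# standalone check of the discounted-return recursion (not part of the training code)
rewards = [0.0, 0.0, 1.0]
gamma = 0.99
returns = [0.0] * len(rewards)
returns[-1] = rewards[-1]
for i in range(len(rewards) - 2, -1, -1):
    returns[i] = rewards[i] + gamma * returns[i + 1]
print(returns)  # approximately [0.9801, 0.99, 1.0]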