Help with A2C Implementation

I’ve been trying to implement a simple A2C online learning algorithm but the network does not seem to learn a good policy. I’m pretty positive that this is just my implementation being wrong and I would really appreciate some help figuring out where my math is wrong.

Around episode 200 or so the model begins to predict the action space of the cart pole environment in OpenAI Gym as [~0, ~.99]. This happens consistently, where the model repeatedly predicts the same action. To keep the post as brief as possible I haven’t included the model but it is a bunch of linear layers with ReLU activations with a critic and actor head respectively.

The actor head has a softmax output and the critic head has no output activation. Can anyone advise me as to what is wrong with my code?

A plot of the rewards over time:

Below is my code:

import gym
import torch
from models import A2CPolicyModel
import numpy as np
import matplotlib.pyplot as plt

#discount factor
GAMMA = 0.99
#entropy penalty coefficient
BETA = 0.001
LR = 1e-3

#create env
env = gym.make("CartPole-v1")

NUM_ACTIONS = env.action_space.n

#make buffers
#buffer = [] 
actor = A2CPolicyModel(4, NUM_ACTIONS)
optimizer = torch.optim.Adam(actor.parameters(), lr=LR)

episode_rewards = []

for episode in range(NUM_EPISODES):
	state = env.reset()
	done = False
	r = 0
	while not done:
		action_dist, value = actor(torch.from_numpy(state).reshape(1, len(state)).float())
		#take random step, optionally weight with actor probabilities
		dist = torch.distributions.Categorical(probs=action_dist)
		act = dist.sample()
		#act = np.random.choice(NUM_ACTIONS, p=np.asarray(torch.squeeze(action_dist.cpu().detach())))
		new_state, reward, done, _ = env.step(act.detach().data.numpy()[0])

		r += reward

		#A(s,a) = r_t+1 + GAMMA*V(s_t+1) - V(s_t)
		if not done:
			_, future_value = actor(torch.from_numpy(new_state).reshape(1, len(new_state)).float())
			advantage = reward + GAMMA*future_value - value
			advantage = torch.tensor(reward)

		#calc critic loss: 1/T ||A(s,t)||^2
		critic_loss = torch.pow(advantage, 2)#.mean()

		#calc policy loss
		#log_probs = torch.log(action_dist)
		policy_loss = -(advantage.detach()*(dist.log_prob(act)))#.mean()

		#calc entropy
		entropy_penalty = dist.entropy()#torch.mean(BETA*action_dist*log_probs)


		loss = (.5*critic_loss) + policy_loss - (BETA*entropy_penalty)

		state = new_state


#plot episode rewards
plt.scatter(range(NUM_EPISODES), episode_rewards, s=2)
plt.ylabel("Net Reward")
plt.title("Net Reward over each episode")