Hello all,
I've been attempting to write a simple implementation of the A2C algorithm to play Atari games. At first I had some success (the model would train, but performance was poor); I then made some modifications and foolishly didn't back up my previous working version. Now the model collapses into NaNs alarmingly fast, typically at the point where I apply a softmax or build a distribution. It has happened during the training step too, but interestingly it occurs more often during inference, while the agent is collecting state-action pairs. I have tried the following:
- Gradient clipping (0.5 to 1.0)
- Lower learning rates (0.0001 for both the actor and the critic)
- Clamping the logits to prevent large values from blowing up the softmax
- Different activation functions (ReLU, LeakyReLU, and SELU for its self-normalising effect)
- Removing the recurrent elements from my networks
- torch.autograd anomaly detection (didn't help, as the NaNs occur almost randomly and anywhere!); a sketch of roughly how I applied the clipping, clamping, and anomaly detection is just below
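For reference, this is roughly how I've been applying the gradient clipping, logit clamping, and anomaly detection (a simplified, self-contained sketch rather than my actual code; the linear layer, dummy loss, and tensor shapes are just stand-ins):

import torch

torch.autograd.set_detect_anomaly(True)  # anomaly detection while debugging

# stand-ins for my real network and optimiser
actor = torch.nn.Linear(4, 6)
actor_optim = torch.optim.Adam(actor.parameters(), lr=0.0001)
obs = torch.randn(1, 4)

# clamp the logits so a single large value can't blow up the softmax
action_logits = torch.clamp(actor(obs), min=-20.0, max=20.0)
loss = -torch.log_softmax(action_logits, dim=-1).mean()  # dummy loss for the example

# gradient clipping between the backward pass and the optimiser step
actor_optim.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=0.5)
actor_optim.step()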
Unfortunately, I've had no success. I'd really appreciate some help in tracking down the problem, as well as any feedback to help me improve my PyTorch code in future. Below I've included the error and the code of my A2C learning loop.
Error message received:
Episode 2 over, session frames : 871, total frames : 1635, current greedy : 0.9960, total rewards : -21.0000
Traceback (most recent call last):
File "c:\Users\jakey\Desktop\RL\Atari\atari.py", line 71, in <module>
dists = torch.distributions.Categorical(probs=A2C_Model.Actor.softmax(action_logits))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\jakey\anaconda3\envs\PyTorch\Lib\site-packages\torch\distributions\categorical.py", line 71, in __init__
super().__init__(batch_shape, validate_args=validate_args)
File "C:\Users\jakey\anaconda3\envs\PyTorch\Lib\site-packages\torch\distributions\distribution.py", line 70, in __init__
raise ValueError(
ValueError: Expected parameter probs (Tensor of shape (1, 6)) of distribution Categorical(probs: torch.Size([1, 6])) to satisfy the constraint Simplex(), but found invalid values:
tensor([[nan, nan, nan, nan, nan, nan]], device='cuda:0',
grad_fn=<DivBackward0>)
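To show what I mean by the NaNs appearing at the distribution step, this is the sort of check I've added just before the distribution is built, to see whether the logits themselves are already bad (the variable name matches my loop below):

if torch.isnan(action_logits).any() or torch.isinf(action_logits).any():
    print("bad logits:", action_logits)
    raise RuntimeError("actor produced nan/inf logits")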
Agent exploration loop:
import torch
from random import random
from torchvision import transforms

image_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(),
    transforms.ToTensor()
])

while not episode_over:
    # preprocess the raw frame into a (1, 1, 210, 160) grayscale tensor
    obs = image_transform(obs)
    obs = torch.as_tensor(obs, dtype=torch.float32)
    obs = obs.unsqueeze(0).to(device=device)
    observations.append(obs)

    # stack every second frame from the last 7 (up to 4 frames) along the channel dim
    obs = torch.cat(observations[-7:][::2], dim=1)
    # at the start of an episode, pad with blank frames until there are 4 channels
    while obs.shape[1] < 4:
        obs = torch.cat((obs, torch.zeros(1, 1, 210, 160, device=device)), dim=1)
    A2C_Model.observations.append(obs)

    # sample an action from the policy
    action_logits = A2C_Model.Actor(obs)
    dists = torch.distributions.Categorical(probs=A2C_Model.Actor.softmax(action_logits))
    action = dists.sample()
    # with probability `greedy`, resample from a temperature-scaled softmax (exploration)
    if random() < greedy:
        probs = torch.nn.functional.softmax(action_logits / temperature, dim=-1)
        dists = torch.distributions.Categorical(probs=probs)
        action = dists.sample()

    A2C_Model.log_probs.append(dists.log_prob(action))
    A2C_Model.entropy.append(dists.entropy())

    obs, reward, term, trunc, info = env.step(action.cpu().detach().item())
    A2C_Model.rewards.append(reward)
    episode_over = term or trunc or len(A2C_Model.observations) >= 10000
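One change I've been wondering about (but haven't confirmed makes a difference) is dropping my own softmax and passing the raw, temperature-scaled logits straight to Categorical, which normalises them internally in a more numerically stable way:

# possible alternative: let Categorical normalise the raw logits itself
dists = torch.distributions.Categorical(logits=action_logits / temperature)
action = dists.sample()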
A2C training loop:
def update_model(self, discount=0.85, beta=0.001, n_steps=4, clippings=0.5):
    actor_loss = []
    critic_loss = []
    for t in range(len(self.observations) - 1):
        # current value estimate (keeps its graph for the critic loss)
        Vt = self.Critic(self.observations[t])

        # n-step discounted return, computed without tracking gradients
        with torch.no_grad():
            max_k = min(n_steps, len(self.observations) - 1 - t)
            R = 0.0
            for k in range(max_k):
                R += (discount ** k) * self.rewards[t + k]
            # bootstrap from the critic when a full n-step window fits before the end
            if t + n_steps <= len(self.observations) - 2:
                R += (discount ** max_k) * self.Critic(self.observations[t + max_k])
            Vtar = R

        advantage = Vtar - Vt.detach()
        critic_loss.append((Vtar - Vt) ** 2)
        actor_loss.append(-self.log_probs[t] * advantage - beta * self.entropy[t])

    # critic update
    critic_loss = torch.stack(critic_loss)
    self.critic_optim.zero_grad()
    critic_loss.mean().backward()
    torch.nn.utils.clip_grad_norm_(self.Critic.parameters(), clippings)
    self.critic_optim.step()

    # actor update
    actor_loss = torch.stack(actor_loss)
    self.actor_optim.zero_grad()
    actor_loss.mean().backward()
    torch.nn.utils.clip_grad_norm_(self.Actor.parameters(), clippings)
    self.actor_optim.step()

    self.reset()
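For clarity, the target I'm trying to compute inside the loop above is the usual n-step return, R_t = r_t + γ*r_{t+1} + ... + γ^(n-1)*r_{t+n-1} + γ^n*V(s_{t+n}) when a bootstrap state exists. Here's a small standalone version of that calculation to check my loop against (plain Python floats, no autograd; rewards and values are just lists, with values[i] being the critic's estimate for step i):

def n_step_return(rewards, values, t, n_steps=4, discount=0.85):
    # discounted sum of up to n_steps rewards starting at step t
    max_k = min(n_steps, len(rewards) - t)
    R = sum((discount ** k) * rewards[t + k] for k in range(max_k))
    # bootstrap from the critic's estimate if a state exists after the window
    if t + max_k < len(values):
        R += (discount ** max_k) * values[t + max_k]
    return R

# e.g. n_step_return([1.0, 0.0, 0.0, 0.0, 1.0], [0.5, 0.4, 0.3, 0.2, 0.1], t=0)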