Model almost instantly produces "nan"

Hello all,
I’ve been attempting to create a simple implementation of the A2C algorithm to play Atari games. At first I had some success (the model would train, but performance was poor), then I decided to make some modifications and foolishly didn’t back up my previous working version. Now my model collapses into “nan” alarmingly fast, typically when I apply softmax or create distributions. It has also happened during training steps, but interestingly it happens more often during inference, while the agent is collecting state-action pairs. I have tried the following:

  • Gradient clipping (0.5 to 1.0)
  • Lower learning rates (0.0001 for both actor and critic)
  • Clamping the logits to prevent large values from exploding the softmax
  • Different activation functions (ReLU, LeakyReLU, SELU for its normalisation effects)
  • Removing recurrent elements from my networks
  • Torch anomaly detection (didn’t help, as the nans occur almost randomly and anywhere!)
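Anomaly detection (`torch.autograd.set_detect_anomaly`) only checks for nan during the backward pass, so it cannot catch a nan that appears while the agent is just collecting data. A complementary approach is a forward hook on every submodule that raises as soon as one produces a non-finite output. This is a minimal sketch; `register_nan_hooks` is a hypothetical helper, while `named_modules` and `register_forward_hook` are standard `nn.Module` methods.

```python
import torch
import torch.nn as nn


def register_nan_hooks(model: nn.Module):
    """Attach a forward hook to every submodule that raises a RuntimeError
    the first time a module returns a tensor containing NaN or inf."""
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Some modules (e.g. nn.GRU) return tuples, so normalise to a tuple
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outputs:
                if torch.is_tensor(out) and not torch.isfinite(out).all():
                    # Report whether the bad values came in with the input
                    inputs_finite = all(
                        torch.isfinite(x).all() for x in inputs if torch.is_tensor(x)
                    )
                    raise RuntimeError(
                        f"non-finite output from {name} ({module.__class__.__name__}); "
                        f"inputs finite: {inputs_finite}"
                    )
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name or "<root>")))
    return handles  # call .remove() on each handle to detach the hooks
```

Child hooks fire before their parent's, so the error names the innermost module where the nan first appears; if it reports `inputs finite: False` on the first layer, the observation itself was already bad.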

Unfortunately, I’ve had no success. I’d really appreciate some help fixing this problem, as well as some feedback to help improve my PyTorch programming in the future. Below I’ve included the error and the code of my A2C learning loop.

Error message received:

Episode 2 over, session frames : 871, total frames : 1635, current greedy : 0.9960, total rewards : -21.0000
Traceback (most recent call last):
  File "c:\Users\jakey\Desktop\RL\Atari\atari.py", line 71, in <module>
    dists = torch.distributions.Categorical(probs=A2C_Model.Actor.softmax(action_logits))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jakey\anaconda3\envs\PyTorch\Lib\site-packages\torch\distributions\categorical.py", line 71, in __init__
    super().__init__(batch_shape, validate_args=validate_args)
  File "C:\Users\jakey\anaconda3\envs\PyTorch\Lib\site-packages\torch\distributions\distribution.py", line 70, in __init__
    raise ValueError(
ValueError: Expected parameter probs (Tensor of shape (1, 6)) of distribution Categorical(probs: torch.Size([1, 6])) to satisfy the constraint Simplex(), but found invalid values:
tensor([[nan, nan, nan, nan, nan, nan]], device='cuda:0',
       grad_fn=<DivBackward0>)
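The `DivBackward0` in this error is most likely produced by `Categorical` itself: when given `probs=...`, it renormalises them by dividing by their sum over the last dimension, so any nan already present in the softmax output (or a row that sums to zero) is still nan after that division. The sketch below, using made-up logit values, shows how a naive exp/sum softmax can overflow in float32, compared with `torch.softmax` and with passing logits straight to `Categorical`, which normalises them with `logsumexp`.

```python
import torch
from torch.distributions import Categorical

logits = torch.tensor([[1000.0, 0.0, -5.0]])  # illustrative values only

# Hand-written softmax: exp(1000) overflows to inf in float32, and inf / inf is nan
naive = logits.exp() / logits.exp().sum(dim=-1, keepdim=True)

# torch.softmax subtracts the row maximum internally, so it stays finite here
stable = torch.softmax(logits, dim=-1)

# Categorical(logits=...) normalises with logsumexp instead of dividing probabilities
dist = Categorical(logits=logits)

# Categorical(probs=...) validates its input and rejects the nan row,
# the same kind of ValueError as in the traceback above
try:
    Categorical(probs=naive, validate_args=True)
    rejected = False
except ValueError:
    rejected = True
```

Note that `logits=` only avoids overflow; if the actor's logits are already nan, every variant still produces nan.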

Agent exploration loop:

    from random import random

    import torch
    from torchvision import transforms

    # Convert each RGB frame to a 1 x 210 x 160 float32 tensor in [0, 1]
    image_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Grayscale(),
        transforms.ToTensor()
    ])

    while not episode_over:
        obs = image_transform(obs)
        obs = obs.unsqueeze(0).to(device=device)          # 1 x 1 x 210 x 160
        observations.append(obs)
        # Stack every second frame of the last 7 into up to 4 channels
        obs = torch.cat(observations[-7:][::2], dim=1)
        # Zero-pad to 4 channels at the start of an episode
        while obs.shape[1] < 4:
            obs = torch.cat((obs, torch.zeros(1, 1, 210, 160, device=device)), dim=1)

        A2C_Model.observations.append(obs)
        action_logits = A2C_Model.actor(obs)
        dists = torch.distributions.Categorical(probs=A2C_Model.Actor.softmax(action_logits))
        action = dists.sample()
        if random() < greedy:
            # Sample from a temperature-scaled version of the same logits
            tempered_probs = torch.softmax(action_logits / temperature, dim=-1)
            dists = torch.distributions.Categorical(probs=tempered_probs)
            action = dists.sample()

        A2C_Model.log_probs.append(dists.log_prob(action))
        A2C_Model.entropy.append(dists.entropy())
        obs, reward, term, trunc, info = env.step(action.item())
        A2C_Model.rewards.append(reward)

        # End the episode on termination, truncation, or after 10000 stored steps
        episode_over = term or trunc or len(A2C_Model.observations) >= 10000
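A minimal sketch of one action-selection step that stays in logit space. `select_action` is a hypothetical helper, and `temperature` is assumed to be a positive float like the one used in the loop above. It checks the logits before building the distribution, so a nan coming out of the actor fails with an explicit message at the point it enters the rollout rather than inside `Categorical`'s constraint check.

```python
import torch
from torch.distributions import Categorical


def select_action(action_logits: torch.Tensor, temperature: float = 1.0):
    """Sample one action per row of `action_logits` (shape [batch, n_actions]).

    Dividing logits by temperature < 1 sharpens the policy and > 1 flattens it.
    Categorical(logits=...) normalises with logsumexp, so no explicit softmax
    or division by a sum is needed.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if not torch.isfinite(action_logits).all():
        # Fail early with a clearer message than the Simplex() constraint error
        raise RuntimeError(f"non-finite action logits: {action_logits}")
    dist = Categorical(logits=action_logits / temperature)
    action = dist.sample()
    # log_prob and entropy come from the same distribution the action was drawn from
    return action, dist.log_prob(action), dist.entropy()
```

In the loop this could replace the softmax and `Categorical` lines, for example `action, log_prob, entropy = select_action(A2C_Model.actor(obs), temperature)`, with `action.item()` passed to `env.step`.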

A2C training loop:

    def update_model(self, discount=0.85, beta=0.001, n_steps=4, clippings=0.5):
        actor_loss = []
        critic_loss = []
        for t in range(len(self.observations)-1):
            Vt = self.critic(self.observations[t])
            with torch.no_grad():
                max_k = min(n_steps, len(self.observations)-1 - t)
                R = 0.0
                for k in range(max_k):
                    R += (discount ** k) * self.rewards[t + k]
                if t + n_steps <= len(self.observations)-2: 
                    R += (discount ** max_k) * self.critic(self.observations[t + max_k]).detach()
                Vtar = R

                advantage = Vtar - Vt.clone().detach()

            critic_loss.append(((Vtar - Vt) ** 2))
            actor_loss.append((-self.log_probs[t] * advantage.detach()) - beta*self.entropy[t])

        critic_loss = torch.stack(critic_loss)
        self.critic_optim.zero_grad()
        critic_loss.mean().backward()
        torch.nn.utils.clip_grad_norm_(self.Critic.parameters(), clippings)
        self.critic_optim.step()

        actor_loss = torch.stack(actor_loss)
        self.actor_optim.zero_grad()
        actor_loss.mean().backward()
        torch.nn.utils.clip_grad_norm_(self.Actor.parameters(), clippings)
        self.actor_optim.step()

        self.reset()   
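For comparison, a sketch of the n-step targets computed in one place, assuming `rewards` is a 1-D float tensor and `values` holds the detached critic outputs V(s_0) .. V(s_{T-1}). It is the standard n-step formulation rather than a copy of `update_model`: it always bootstraps from V(s_{t+m}) and handles the end of the rollout through `last_value`, which should be 0.0 when the episode terminated and the critic's estimate for the next state otherwise. `n_step_targets` is a hypothetical helper.

```python
import torch


def n_step_targets(rewards, values, last_value, n_steps=4, discount=0.85):
    """R_t = sum_{k<m} discount**k * r_{t+k} + discount**m * V(s_{t+m}),
    with m = min(n_steps, T - t).

    rewards:    float tensor [T] with r_0 .. r_{T-1}
    values:     detached critic outputs [T] for s_0 .. s_{T-1}
    last_value: V(s_T), or 0.0 if the episode terminated after the last step
    """
    T = rewards.shape[0]
    # Append V(s_T) so every window can bootstrap with the same expression
    v_ext = torch.cat([values, values.new_tensor([last_value])])  # [T + 1]
    targets = torch.empty(T, dtype=v_ext.dtype, device=v_ext.device)
    for t in range(T):
        m = min(n_steps, T - t)  # window is truncated at the end of the rollout
        weights = discount ** torch.arange(m, dtype=v_ext.dtype, device=v_ext.device)
        targets[t] = (weights * rewards[t:t + m]).sum() + (discount ** m) * v_ext[t + m]
    return targets
```

With these targets, the advantages are `targets - values` and the critic loss is the mean squared error between the critic's (non-detached) predictions and `targets`. Computing `values` with a single batched critic call under `torch.no_grad()` also avoids re-running the critic inside the loop, with the caveat that BatchNorm2d layers in training mode normalise over whatever batch they receive.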

Upon further investigation, I’ve found that the problem must lie within my loss function: both the actor and critic losses evaluate to nan.

tensor(nan, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(nan, device='cuda:0', grad_fn=<MeanBackward0>)
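Once a loss is nan, backpropagating it writes nan gradients into the parameters it depends on, and a single optimizer step then makes the weights themselves nan (and, for stateful optimizers such as Adam, the moment estimates too). That could explain why later forward passes during data collection also produce nan. Gradient clipping does not prevent this, because the total norm of a nan gradient is itself nan. A hedged sketch of an update that refuses to step in that case; `guarded_step` is a hypothetical helper, while `clip_grad_norm_` returning the total (pre-clipping) norm is standard PyTorch behaviour.

```python
import torch
import torch.nn as nn


def guarded_step(loss: torch.Tensor, module: nn.Module,
                 optimizer: torch.optim.Optimizer, max_norm: float = 0.5) -> bool:
    """Backpropagate `loss`, clip gradients and step `optimizer`, but skip the
    step (returning False) if the loss or the total gradient norm is NaN/inf."""
    optimizer.zero_grad()
    if not torch.isfinite(loss).all():
        return False                      # never backpropagate a non-finite loss
    loss.backward()
    # clip_grad_norm_ returns the total norm computed before clipping
    total_norm = torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm)
    if not torch.isfinite(total_norm):
        optimizer.zero_grad()             # discard the poisoned gradients
        return False
    optimizer.step()
    return True
```

In `update_model` this would stand in for each zero_grad / backward / clip / step sequence, for example `guarded_step(critic_loss.mean(), self.Critic, self.critic_optim, clippings)`. Skipping steps only contains the damage; logging when it returns False is what helps locate the cause.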

Given your suggestion that the nan values are created in the loss function, and the previously shared error message:

ValueError: Expected parameter probs (Tensor of shape (1, 6)) of distribution Categorical(probs: torch.Size([1, 6])) to satisfy the constraint Simplex(), but found invalid values:
tensor([[nan, nan, nan, nan, nan, nan]], device='cuda:0',
       grad_fn=<DivBackward0>)

Do you see any divisions inside the loss calculations that could create these invalid outputs?

I investigated where the nan values could be coming from and found that they first appear when I calculate the target value Vtar. Where those values come from, I’m completely unsure. The only possible culprit I can think of is the BatchNorm2d in my model’s convolutional layers. However, in a previous attempt to debug my model I commented out the BatchNorm2d layers, and, strangely, the activation function used in the convolutional layers changes how suddenly the nan values appear: with SELU they can appear after only 2 training iterations, while with ReLU a few dozen iterations can pass before I see nan values.
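Since Vtar is built from critic outputs, one hedged way to test the BatchNorm2d theory is to scan each network's parameters, their gradients, and its buffers (which include BatchNorm2d's `running_mean` and `running_var`) after every update, and note the first update at which anything becomes non-finite. `non_finite_tensors` is a hypothetical helper built on the standard `named_parameters` and `named_buffers` methods.

```python
import torch
import torch.nn as nn


def non_finite_tensors(model: nn.Module):
    """Return the names of parameters, gradients and floating-point buffers
    (e.g. BatchNorm running_mean / running_var) that contain NaN or inf."""
    bad = []
    for name, p in model.named_parameters():
        if not torch.isfinite(p).all():
            bad.append(name)
        if p.grad is not None and not torch.isfinite(p.grad).all():
            bad.append(name + ".grad")
    for name, b in model.named_buffers():
        if b.is_floating_point() and not torch.isfinite(b).all():
            bad.append(name)
    return bad
```

Keep in mind that BatchNorm's running statistics only affect the output in `eval()` mode; in training mode each forward pass normalises with the statistics of the current batch.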

Below I’ll include my model’s architecture and the part of my training function where the nan values appear.

Model architecture (identical for the actor and critic):

        self.critic_head = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=(3,3), stride=(1,1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3,3), stride=(2,2)),

            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=32, eps=0.001),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3,3), stride=(1,1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), stride=(2,2)),

            nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=64, eps=0.001),

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), stride=(1,1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), stride=(2,2)),

            nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=64, eps=0.001),
            nn.Flatten(),
            nn.Linear(256, 64),
            nn.ReLU()
        )

        self.critic_GRU = nn.GRU(input_size=hidden_space, hidden_size=hidden_space, num_layers=recurrent_layers, batch_first=True)
        self.critic_network = nn.Sequential(
            nn.Linear(in_features=64, out_features=64),
            nn.SELU(),
            nn.Linear(in_features=64, out_features=1),
        )
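To double-check the hard-coded `nn.Linear(256, 64)`, this sketch rebuilds `critic_head` as written and pushes a dummy 4 x 210 x 160 frame stack through it. Working the Conv2d/MaxPool2d output-size formula by hand gives heights 210 → 208 → 103 → 51 → 49 → 24 → 12 → 10 → 4 → 2 and widths 160 → 158 → 78 → 39 → 37 → 18 → 9 → 7 → 3 → 2, so the flattened size is 64 × 2 × 2 = 256. `conv_block` is a hypothetical helper that groups the three repeated stages.

```python
import torch
import torch.nn as nn


def conv_block(c_in, c_out, ceil_mode):
    # conv(3x3, stride 1) -> ReLU -> conv(3x3, stride 2) -> maxpool -> ReLU -> BatchNorm,
    # mirroring one stage of critic_head above
    return [
        nn.Conv2d(c_in, c_out, kernel_size=3, stride=1),
        nn.ReLU(),
        nn.Conv2d(c_out, c_out, kernel_size=3, stride=2),
        nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode),
        nn.ReLU(),
        nn.BatchNorm2d(c_out, eps=0.001),
    ]


head = nn.Sequential(
    *conv_block(4, 32, ceil_mode=False),
    *conv_block(32, 64, ceil_mode=True),
    *conv_block(64, 64, ceil_mode=True),
    nn.Flatten(),
    nn.Linear(256, 64),
    nn.ReLU(),
)

head.eval()  # use running statistics so the check does not depend on batch size
with torch.no_grad():
    dummy = torch.zeros(1, 4, 210, 160)
    conv_out = head[:-3](dummy)   # output of the last BatchNorm2d
    features = head(dummy)
```

A shape mismatch here would raise a RuntimeError rather than produce nan, so if this runs, the layer sizes themselves are unlikely to be the source of the problem.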

Region of code where nan values originate from:

        for t in range(len(self.observations)-1):
            Vt = self.critic(self.observations[t])
            with torch.no_grad():
                max_k = min(n_steps, len(self.observations)-1 - t)
                R = 0.0
                for k in range(max_k):
                    R += (discount ** k) * self.rewards[t + k]
                if t + n_steps <= len(self.observations)-2: 
                    R += (discount ** max_k) * self.critic(self.observations[t + max_k]).detach()
                Vtar = R

                advantage = Vtar - Vt.clone().detach()

I hope the information I’ve provided here helps. I really am quite stumped with this problem; thank you for the help, it is much appreciated.

So, it works now. I left my PC turned off for a few days, came back, and it’s working. I read another forum thread where the same thing happened to another user, so I guess it does indeed happen sometimes. I’ll link that thread below.
Restarting resolved nan for a user