Deep Active Inference: Issues with NaN predictions

Hello there.

I am attempting to build an “Active Inference” agent out of several neural network models. Active Inference is a paradigm closely related to Reinforcement Learning, so I thought it best to post here.

There is a great deal of context that might be relevant to my queries, but I will be as terse as possible here. I am more than happy to elaborate on anything, so please do ask.

In essence, I want to combine deep active inference with heuristic tree search for planning. As an intermediate step toward this end, I simply want to blend the approaches found in these two papers: https://www.frontiersin.org/articles/10.3389/fncom.2020.574372/full and Deep Active Inference for Partially Observable MDPs (arXiv:2009.03622).

I am using the CartPole-v1 environment (Cart Pole - Gymnasium Documentation), so the action space has cardinality 2 and the observation space is 4-dimensional. My code is a heavily modified version of that found in the repo accompanying the second of the above papers.

The first paper uses various Gaussian neural networks to approximate the agent’s generative model. These networks can then be trained on a free energy loss function, thus instantiating variational inference of the kind used in the “Active Inference” paradigm. The second paper finesses the issue of evaluating future actions via a bootstrapped estimate of the Expected Free Energy. This allows the agent to “plan” without explicitly maintaining a search tree over counterfactual trajectories.
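
To make the first of these concrete: the loss I am minimising for the generative model is, roughly, the per-sample variational free energy, i.e. a KL term between the state posterior and the state prior plus a negative log-likelihood term for the observation. A minimal sketch of what I mean, assuming diagonal Gaussians throughout, dropping constants and glossing over the exact ordering of the KL arguments (the names here are illustrative, not my exact code):

def vfe_loss(mu_q, logvar_q, mu_p, logvar_p, obs, obs_mu, obs_logvar):
    # KL[ q_phi(s_t) || p_theta(s_t | s_{t-1}, a_{t-1}) ] for diagonal Gaussians
    kl = 0.5 * (logvar_p - logvar_q
                + (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp()
                - 1.0).sum(dim=1)
    # -E_q[ log p_xi(o_t | s_t) ], evaluated at a reparameterised sample of s_t
    nll = 0.5 * (obs_logvar + (obs - obs_mu) ** 2 / obs_logvar.exp()).sum(dim=1)
    return (kl + nll).mean()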

I am encountering a very simple error in the forward operation of one of my neural networks. I don’t think the error has anything to do with the theory of Active Inference per se, so it should be tractable for anyone with sufficient knowledge of PyTorch (which, it seems, I do not have). Basically, after some period of training, one of my networks starts to produce NaN outputs. Here is the exact output:

Traceback (most recent call last):
  File "/home/...ai_pomdp_agent_ACT_No_VAE.py", line 2025, in <module>
    agent.train_models()
  File "/home/...ai_pomdp_agent_ACT_No_VAE.py", line 1755, in train_models
    VFE, value_net_psi_loss, expected_log_ev = self.learn()
  File "/home/...ai_pomdp_agent_ACT_No_VAE.py", line 1481, in learn
    expected_log_ev = self.mc_expected_log_evidence(z_batch_t1)
  File "/home/...ai_pomdp_agent_ACT_No_VAE.py", line 1315, in mc_expected_log_evidence
    multivariate_normal_p = torch.distributions.MultivariateNormal(
  File "/home/...multivariate_normal.py", line 146, in __init__
    super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
  File "/home/...distribution.py", line 55, in __init__
    raise ValueError(
ValueError: Expected parameter loc (Tensor of shape (32, 4)) of distribution MultivariateNormal(loc: torch.Size([32, 4]), covariance_matrix: torch.Size([32, 4, 4])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan]], grad_fn=<ExpandBackward0>)
(all 32 rows are nan; truncated here for brevity)

Now I’ve traced the source of this error all the way back to my get_mini_batches function:

def get_mini_batches(self):

        all_obs_batch, all_actions_batch, reward_batch_t1, terminated_batch_t2, truncated_batch_t2, hidden_state_batch = self.memory.sample(
            self.obs_indices, self.action_indices, self.reward_indices, 
            self.terminated_indices, self.truncated_indices, self.hidden_state_indices, 
            self.max_n_indices, self.batch_size
        )

        # Retrieve a batch of inferred hidden states for 3 consecutive points in time
        inferred_state_batch_t0 = hidden_state_batch[:, 0].view([self.batch_size] + [dim for dim in self.obs_shape])
        inferred_state_batch_t1 = hidden_state_batch[:, 1].view([self.batch_size] + [dim for dim in self.obs_shape])
        inferred_state_batch_t2 = hidden_state_batch[:, 2].view([self.batch_size] + [dim for dim in self.obs_shape])

        # Retrieve a batch of observations, hidden states, and actions for consecutive points in time
        # obs_batch_t0 = all_obs_batch[:, 0, :].view(self.batch_size, -1)  # Most recent observation
        obs_batch_t1 = all_obs_batch[:, 1, :].view(self.batch_size, -1)  # Second most recent observation
        obs_batch_t2 = all_obs_batch[:, 2, :].view(self.batch_size, -1)  # Third most recent observation
        obs_batch_t3 = all_obs_batch[:, 3, :].view(self.batch_size, -1)  # Fourth most recent observation 

        # Retrieve the agent's action history for time t0, t1, t2 and t3
        action_batch_t0 = all_actions_batch[:, 0].unsqueeze(1)
        action_batch_t1 = all_actions_batch[:, 1].unsqueeze(1)
        action_batch_t2 = all_actions_batch[:, 2].unsqueeze(1)
        action_batch_t3 = all_actions_batch[:, 3].unsqueeze(1)

        q_phi_inputs_t0 = torch.cat((inferred_state_batch_t0, action_batch_t1, obs_batch_t1), dim = 1) # s_t, a_{t + 1}, o_{t + 1}
        q_phi_inputs_t1 = torch.cat((inferred_state_batch_t1, action_batch_t2, obs_batch_t2), dim = 1) # s_{t + 1}, a_{t + 2}, o_{t + 2}
        q_phi_inputs_t2 = torch.cat((inferred_state_batch_t2, action_batch_t3, obs_batch_t3), dim = 1) # s_{t + 2}, a_{t + 3}, o_{t + 3}

        # Retrieve a batch of distributions over states
        state_mu_batch_t0, state_logvar_batch_t0 = self.posterior_transition_net_phi(q_phi_inputs_t0) # \mu{s_t}, \log{\Sigma^2(s_t)}
        state_mu_batch_t1, state_logvar_batch_t1 = self.posterior_transition_net_phi(q_phi_inputs_t1) # \mu{s_{t + 1}}, \log{\Sigma^2(s_{t + 1})}
        state_mu_batch_t2, state_logvar_batch_t2 = self.posterior_transition_net_phi(q_phi_inputs_t2) # \mu{s_{t + 2}}, \log{\Sigma^2(s_{t + 2})}

        # Reparameterize the distribution over states for time t0 and t1
        z_batch_t0 = self.posterior_transition_net_phi.rsample(state_mu_batch_t0, state_logvar_batch_t0)
        z_batch_t1 = self.posterior_transition_net_phi.rsample(state_mu_batch_t1, state_logvar_batch_t1)

        # At time t0 predict the state at time t1:
        X = torch.cat((state_mu_batch_t0.detach(), action_batch_t0.float()), dim = 1)
        pred_batch_mean_t0t1, pred_batch_logvar_t0t1 = self.prior_transition_net_theta(X)

        # Determine the prediction error wrt time t0-t1 using state KL Divergence:
        pred_error_batch_t0t1 = torch.sum(
            self.gaussian_kl_div(
                pred_batch_mean_t0t1, torch.exp(pred_batch_logvar_t0t1),
                state_mu_batch_t1, torch.exp(state_logvar_batch_t1)
            ), dim=1
        ).unsqueeze(1)

        print(f"\n\nget_mini_batches - state_mu_batch_t0:\n{state_mu_batch_t0}")
        print(f"get_mini_batches - state_logvar_batch_t0:\n{state_logvar_batch_t0}")
        print(f"get_mini_batches - state_mu_batch_t1:\n{state_mu_batch_t1}")
        print(f"get_mini_batches - state_logvar_batch_t1:\n{state_logvar_batch_t1}\n\n")

        return (
            state_mu_batch_t1, state_logvar_batch_t1, 
            state_mu_batch_t2, state_logvar_batch_t2, 
            action_batch_t1, reward_batch_t1, 
            terminated_batch_t2, truncated_batch_t2, pred_error_batch_t0t1,
            obs_batch_t1, state_mu_batch_t1,
            state_logvar_batch_t1, z_batch_t0, z_batch_t1
        )
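
(For reference, gaussian_kl_div is a small helper of my own. Given the means and variances of two diagonal Gaussians it returns the elementwise KL terms, which the caller then sums over the state dimension. It is along these lines, though this is a simplified sketch and my actual implementation may differ slightly:)

def gaussian_kl_div(self, mu_1, var_1, mu_2, var_2):
    # Elementwise KL[ N(mu_1, var_1) || N(mu_2, var_2) ] for diagonal Gaussians
    return 0.5 * (torch.log(var_2 / var_1) + (var_1 + (mu_1 - mu_2) ** 2) / var_2 - 1.0)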

The print statements at the end of get_mini_batches produce the following (each tensor is 32 × 4 and every entry is nan; rows truncated here for brevity):

get_mini_batches - state_mu_batch_t0:
tensor([[nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan]], grad_fn=<AddmmBackward0>)
get_mini_batches - state_logvar_batch_t0:
tensor([[nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan]], grad_fn=<AddmmBackward0>)
get_mini_batches - state_mu_batch_t1:
tensor([[nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan]], grad_fn=<AddmmBackward0>)
get_mini_batches - state_logvar_batch_t1:
tensor([[nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan]], grad_fn=<AddmmBackward0>)

It would appear that my self.posterior_transition_net_phi is the culprit. I’ve checked the inputs to self.posterior_transition_net_phi and none of them ever seem problematic; there are no inf or NaN values in the inputs.
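
For what it is worth, this is roughly the kind of check I run on those inputs (simplified; assert_finite is just an illustrative helper, not something from a library):

def assert_finite(name, x):
    # Fail fast as soon as any NaN or Inf appears in a tensor
    if torch.isnan(x).any() or torch.isinf(x).any():
        raise ValueError(f"{name} contains NaN/Inf values")

assert_finite("q_phi_inputs_t0", q_phi_inputs_t0)
assert_finite("q_phi_inputs_t1", q_phi_inputs_t1)
assert_finite("q_phi_inputs_t2", q_phi_inputs_t2)

Since the inputs look fine but the outputs are all NaN, my (unconfirmed) suspicion is that the parameters of posterior_transition_net_phi themselves become NaN at some point during training, presumably after an optimiser step, but I have not managed to pin down where.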

For the sake of completeness, here is the constructor of my Agent class. This is where I initialise the self.posterior_transition_net_phi network:

class Agent():
    
    def __init__(self, argv):
        
        self.set_parameters(argv) # Set parameters
        
        # self.obs_shape = (4,) 
        self.obs_shape = self.env.observation_space.shape # The shape of observations
        self.obs_size = np.prod(self.obs_shape) # The size of the observation
        self.n_actions = self.env.action_space.n # The number of actions available to the agent
        self.all_actions = [0, 1]
        self.freeze_cntr = 0 # Keeps track of when to (un)freeze the target network

        # specify the goal prior distribution: p_C:
        self.goal_mu = torch.tensor([0, 0, 0, 0]).to(self.device)
        self.goal_var = torch.tensor([1, 1, 1, 1]).to(self.device)

        # Generative state prior: P_theta
        self.prior_transition_net_theta = MVGaussianModel(
            self.latent_state_dim + 1,      # inputs:  s_{t-1}, a_{t-1}
            self.latent_state_dim,          # outputs: mu(s_t), sigma(s_t)
            self.n_hidden_gen_trans,        # hidden layer
            lr = self.lr_gen_trans,         # learning rate
            device = self.device,
            name = 'prior_theta'
        )

        # Variational state posterior: Q_phi
        self.posterior_transition_net_phi = MVGaussianModel(
            2 * self.latent_state_dim + 1,  # inputs:  s_{t-1}, a_{t-1}, o_t
            self.latent_state_dim,          # outputs: mu(s_t), sigma(s_t)
            self.n_hidden_var_trans,        # hidden layer
            lr = self.lr_var_trans,         # learning rate
            device = self.device,
            name = 'posterior_phi'
        )

        # Generative observation likelihood: P_xi
        self.generative_observation_net_xi = MVGaussianModel(
            self.latent_state_dim,          # input:   s_t (reparam'd state sample)
            self.latent_state_dim,          # outputs: mu(o_t), sigma(o_t)
            self.n_hidden_gen_obs,          # hidden layer
            lr = self.lr_gen_obs,           # learning rate
            device = self.device,
            name = 'obs_xi'
        )

        # Variational Policy prior: Q_nu
        self.policy_net_nu = Model(
            2 * self.latent_state_dim,      # inputs: mu(s_t), sigma(s_t)
            self.n_actions,                 # output: categorical action probs for all actions in action space
            self.n_hidden_pol,              # hidden layer
            lr = self.lr_pol,               # learning rate
            softmax = True, 
            device = self.device
        )

        # EFE bootstrap-estimate network: f_psi
        self.value_net_psi = Model(
            2 * self.latent_state_dim,      # input:  mu(s_t), sigma(s_t)
            self.n_actions,                 # output: Estimated EFE for all actions in actions space
            self.n_hidden_val,              # hidden layer
            lr = self.lr_val,               # learning rate
            device = self.device
        )

        # TARGET EFE bootstrap-estimate network: f_psi
        self.target_net = Model(
            2 * self.latent_state_dim,      # input:  mu(s_t), sigma(s_t)
            self.n_actions,                 # output: Estimated EFE for all actions in actions space
            self.n_hidden_val,              # hidden layer
            lr = self.lr_val,               # learning rate
            device = self.device
        )

        self.target_net.load_state_dict(self.value_net_psi.state_dict())
            
        if self.load_network: # If true: load the networks given paths

            self.prior_transition_net_theta.load_state_dict(torch.load(self.network_load_path.format("trans")))
            self.prior_transition_net_theta.eval()

            self.generative_observation_net_xi.load_state_dict(torch.load(self.network_load_path.format("obs")))
            self.generative_observation_net_xi.eval()

            self.posterior_transition_net_phi.load_state_dict(torch.load(self.network_load_path.format("var_trans")))
            self.posterior_transition_net_phi.eval()

            self.policy_net_nu.load_state_dict(torch.load(self.network_load_path.format("pol"), map_location=self.device))
            self.policy_net_nu.eval()

            self.value_net_psi.load_state_dict(torch.load(self.network_load_path.format("val"), map_location=self.device))
            self.value_net_psi.eval()

        # Initialize the replay memory
        self.memory = ReplayMemory(self.memory_capacity, self.obs_shape, device=self.device)

        # When sampling from memory at index i, obs_indices indicates that we want observations with indices i-obs_indices, works the same for the others
        # self.obs_indices = [(self.n_screens+1)-i for i in range(self.n_screens+2)] # the third most recent, second most recent and most recent observation???
        self.obs_indices = [3, 2, 1, 0] # i.e. the observations at indices i-3, i-2, i-1 and i
        self.hidden_state_indices = [3, 2, 1, 0] # i.e. the inferred hidden states at indices i-3, i-2, i-1 and i
        # self.action_indices = [2, 1]
        self.action_indices = [3, 2, 1, 0]
        self.reward_indices = [1]
        # self.done_indices = [0]
        self.terminated_indices = [0] # I assume it's ok that these are the same?
        self.truncated_indices = [0] # I assume it's ok that these are the same?

        self.max_n_indices = max(max(self.obs_indices, self.action_indices, self.reward_indices,
                                     self.terminated_indices, self.truncated_indices, self.hidden_state_indices)) + 1

Finally, here is the definition of my MVGaussianModel class, since the troublesome self.posterior_transition_net_phi is an instance of this class:

class MVGaussianModel(nn.Module):

    def __init__(self, n_inputs, n_outputs, n_hidden, lr=1e-4, device='cpu', name = None):

        super(MVGaussianModel, self).__init__()

        self.name = name

        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.n_hidden = n_hidden

        self.fc1 = nn.Linear(n_inputs, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)

        self.mean_fc = nn.Linear(n_hidden, n_outputs)
        self.log_var = nn.Linear(n_hidden, n_outputs)

        self.optimizer = optim.Adam(self.parameters(), lr)
        self.device = device
        self.to(self.device)

    def forward(self, x):

        x_1 = torch.relu(self.fc1(x))
        x_2 = torch.relu(self.fc2(x_1))

        # Two separate linear heads: the mean and the log-variance of a diagonal Gaussian
        mean = self.mean_fc(x_2)
        log_var = self.log_var(x_2)

        return mean, log_var

    def rsample(self, mu, log_var):
        # Reparameterisation trick: z = mu + eps * sigma, with eps ~ N(0, I)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(mu)

        return mu + (eps * std)
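
For orientation, this class is used exactly as in get_mini_batches above: a forward pass to obtain the Gaussian parameters, followed by a reparameterised sample. For the troublesome posterior network that looks like:

mu, log_var = self.posterior_transition_net_phi(q_phi_inputs_t1)  # each of shape (32, 4)
z = self.posterior_transition_net_phi.rsample(mu, log_var)        # reparameterised state sample, (32, 4)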

I would be extremely grateful for any assistance, queries or suggestions. As I say, I am very happy to elaborate on any points at all. In an effort to be as comprehensive as possible, I enclose screenshots from a small LaTeX document of mine, which details the structure of all the above networks, in case they are useful.

Many thanks!