Hello there.
I am attempting to create an “Active Inference” agent using several neural network models. Active Inference is a paradigm closely related to Reinforcement Learning, so I thought it best to post here.
There is a great deal of context that might be relevant to my queries, but I will be as terse as possible here. I am more than happy to elaborate, so please do ask about anything at all.
In essence, I want to combine deep active inference with heuristic tree search for planning. As an intermediate step toward this end, I simply want to blend the approaches found in these two papers: https://www.frontiersin.org/articles/10.3389/fncom.2020.574372/full and [2009.03622] Deep Active Inference for Partially Observable MDPs.
I am using the CartPole-v1 environment: Cart Pole - Gymnasium Documentation. Hence the action space has cardinality 2 and the state space is 4-dimensional. My code is a heavily modified version of that found in the repo accompanying the second of the above papers.
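For concreteness, the environment setup is just the standard Gymnasium one, along these lines (a minimal sketch, not my full training script):

import gymnasium as gym

env = gym.make("CartPole-v1")
assert env.action_space.n == 2              # two discrete actions: push cart left / right
assert env.observation_space.shape == (4,)  # cart position, cart velocity, pole angle, pole angular velocity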
The first paper uses various Gaussian neural networks to approximate the agent’s generative model. These networks can then be trained on a free energy loss function, thus instantiating variational inference of the kind used in the “Active Inference” paradigm. The second paper finesses the issue of evaluating future actions via a bootstrapped estimate of the Expected Free Energy. This allows the agent to “plan” without explicitly maintaining a search tree over counterfactual trajectories.
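To make the objective concrete, my understanding (my paraphrase, not a quote from either paper) is that the state networks are trained by minimising a per-timestep variational free energy of roughly this form, where q_phi is the variational state posterior, p_theta the state transition prior, and p_xi the observation likelihood:

F_t \approx -\mathbb{E}_{q_\phi(s_t)}\big[\log p_\xi(o_t \mid s_t)\big] + D_{\mathrm{KL}}\big[\, q_\phi(s_t \mid s_{t-1}, a_{t-1}, o_t) \,\big\|\, p_\theta(s_t \mid s_{t-1}, a_{t-1}) \,\big]

and that a separate value network is regressed onto a bootstrapped Expected Free Energy target, much as a Q-network bootstraps returns.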
I am encountering a very simple error in the forward operation of one of my neural networks. I don’t think the error has anything to do with the theory of Active Inference per se, so it should be tractable for anyone with sufficient knowledge of PyTorch - evidently that is not me. Basically, after some amount of training, one of my networks starts to produce NaN outputs. Here is the exact output:
Traceback (most recent call last):
File "/home/...ai_pomdp_agent_ACT_No_VAE.py", line 2025, in <module>
agent.train_models()
File "/home/...ai_pomdp_agent_ACT_No_VAE.py", line 1755, in train_models
VFE, value_net_psi_loss, expected_log_ev = self.learn()
File "/home/...ai_pomdp_agent_ACT_No_VAE.py", line 1481, in learn
expected_log_ev = self.mc_expected_log_evidence(z_batch_t1)
File "/home/...ai_pomdp_agent_ACT_No_VAE.py", line 1315, in mc_expected_log_evidence
multivariate_normal_p = torch.distributions.MultivariateNormal(
File "/home/...multivariate_normal.py", line 146, in __init__
super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
File "/home/...distribution.py", line 55, in __init__
raise ValueError(
ValueError: Expected parameter loc (Tensor of shape (32, 4)) of distribution MultivariateNormal(loc: torch.Size([32, 4]), covariance_matrix: torch.Size([32, 4, 4])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan]], grad_fn=<ExpandBackward0>)
(all 32 rows of the (32, 4) tensor are [nan, nan, nan, nan]; truncated here for brevity)
Now I’ve traced the source of this error all the way back to my get_mini_batches function:
def get_mini_batches(self):
    all_obs_batch, all_actions_batch, reward_batch_t1, terminated_batch_t2, truncated_batch_t2, hidden_state_batch = self.memory.sample(
        self.obs_indices, self.action_indices, self.reward_indices,
        self.terminated_indices, self.truncated_indices, self.hidden_state_indices,
        self.max_n_indices, self.batch_size
    )

    # Retrieve a batch of inferred hidden states for 3 consecutive points in time
    inferred_state_batch_t0 = hidden_state_batch[:, 0].view([self.batch_size] + [dim for dim in self.obs_shape])
    inferred_state_batch_t1 = hidden_state_batch[:, 1].view([self.batch_size] + [dim for dim in self.obs_shape])
    inferred_state_batch_t2 = hidden_state_batch[:, 2].view([self.batch_size] + [dim for dim in self.obs_shape])

    # Retrieve a batch of observations for consecutive points in time
    # obs_batch_t0 = all_obs_batch[:, 0, :].view(self.batch_size, -1)  # Most recent observation
    obs_batch_t1 = all_obs_batch[:, 1, :].view(self.batch_size, -1)  # Second most recent observation
    obs_batch_t2 = all_obs_batch[:, 2, :].view(self.batch_size, -1)  # Third most recent observation
    obs_batch_t3 = all_obs_batch[:, 3, :].view(self.batch_size, -1)  # Fourth most recent observation

    # Retrieve the agent's action history for time t0, t1, t2 and t3
    action_batch_t0 = all_actions_batch[:, 0].unsqueeze(1)
    action_batch_t1 = all_actions_batch[:, 1].unsqueeze(1)
    action_batch_t2 = all_actions_batch[:, 2].unsqueeze(1)
    action_batch_t3 = all_actions_batch[:, 3].unsqueeze(1)

    q_phi_inputs_t0 = torch.cat((inferred_state_batch_t0, action_batch_t1, obs_batch_t1), dim=1)  # s_t, a_{t+1}, o_{t+1}
    q_phi_inputs_t1 = torch.cat((inferred_state_batch_t1, action_batch_t2, obs_batch_t2), dim=1)  # s_{t+1}, a_{t+2}, o_{t+2}
    q_phi_inputs_t2 = torch.cat((inferred_state_batch_t2, action_batch_t3, obs_batch_t3), dim=1)  # s_{t+2}, a_{t+3}, o_{t+3}

    # Retrieve a batch of distributions over states
    state_mu_batch_t0, state_logvar_batch_t0 = self.posterior_transition_net_phi(q_phi_inputs_t0)  # \mu(s_t), \log\sigma^2(s_t)
    state_mu_batch_t1, state_logvar_batch_t1 = self.posterior_transition_net_phi(q_phi_inputs_t1)  # \mu(s_{t+1}), \log\sigma^2(s_{t+1})
    state_mu_batch_t2, state_logvar_batch_t2 = self.posterior_transition_net_phi(q_phi_inputs_t2)  # \mu(s_{t+2}), \log\sigma^2(s_{t+2})

    # Reparameterize the distribution over states for time t0 and t1
    z_batch_t0 = self.posterior_transition_net_phi.rsample(state_mu_batch_t0, state_logvar_batch_t0)
    z_batch_t1 = self.posterior_transition_net_phi.rsample(state_mu_batch_t1, state_logvar_batch_t1)

    # At time t0 predict the state at time t1:
    X = torch.cat((state_mu_batch_t0.detach(), action_batch_t0.float()), dim=1)
    pred_batch_mean_t0t1, pred_batch_logvar_t0t1 = self.prior_transition_net_theta(X)

    # Determine the prediction error wrt time t0-t1 using the state KL divergence:
    pred_error_batch_t0t1 = torch.sum(
        self.gaussian_kl_div(
            pred_batch_mean_t0t1, torch.exp(pred_batch_logvar_t0t1),
            state_mu_batch_t1, torch.exp(state_logvar_batch_t1)
        ), dim=1
    ).unsqueeze(1)

    print(f"\n\nget_mini_batches - state_mu_batch_t0:\n{state_mu_batch_t0}")
    print(f"get_mini_batches - state_logvar_batch_t0:\n{state_logvar_batch_t0}")
    print(f"get_mini_batches - state_mu_batch_t1:\n{state_mu_batch_t1}")
    print(f"get_mini_batches - state_logvar_batch_t1:\n{state_logvar_batch_t1}\n\n")

    return (
        state_mu_batch_t1, state_logvar_batch_t1,
        state_mu_batch_t2, state_logvar_batch_t2,
        action_batch_t1, reward_batch_t1,
        terminated_batch_t2, truncated_batch_t2, pred_error_batch_t0t1,
        obs_batch_t1, state_mu_batch_t1,
        state_logvar_batch_t1, z_batch_t0, z_batch_t1
    )
The print statements above furnish the following output:
get_mini_batches - state_mu_batch_t0:
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan]], grad_fn=<AddmmBackward0>)

get_mini_batches - state_logvar_batch_t0:
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan]], grad_fn=<AddmmBackward0>)

get_mini_batches - state_mu_batch_t1:
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan]], grad_fn=<AddmmBackward0>)

get_mini_batches - state_logvar_batch_t1:
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan]], grad_fn=<AddmmBackward0>)

(each of these is a 32 x 4 tensor in which every entry is NaN; rows truncated here for brevity)
It would appear that my self.posterior_transition_net_phi is the culprit. I have checked the inputs to self.posterior_transition_net_phi, and none of them ever seem problematic; there are no inf or NaN values in the inputs.
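Concretely, the check I am running on the inputs is along these lines (a simplified sketch, not my exact code; assert_finite is just an illustrative helper name):

import torch

def assert_finite(name: str, t: torch.Tensor) -> None:
    # Fail loudly if the tensor contains NaN or +/- inf
    if not torch.isfinite(t).all():
        raise ValueError(f"{name} contains non-finite values")

# e.g. called just before the posterior network's forward pass:
# assert_finite("q_phi_inputs_t0", q_phi_inputs_t0)
# assert_finite("q_phi_inputs_t1", q_phi_inputs_t1)
# assert_finite("q_phi_inputs_t2", q_phi_inputs_t2)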
For the sake of completeness, here is the constructor of my “Agent” class, which is where I initialise the self.posterior_transition_net_phi network:
class Agent():
    def __init__(self, argv):
        self.set_parameters(argv)  # Set parameters

        # self.obs_shape = (4,)
        self.obs_shape = self.env.observation_space.shape  # The shape of observations
        self.obs_size = np.prod(self.obs_shape)            # The size of an observation
        self.n_actions = self.env.action_space.n           # The number of actions available to the agent
        self.all_actions = [0, 1]

        self.freeze_cntr = 0  # Keeps track of when to (un)freeze the target network

        # Specify the goal prior distribution p_C:
        self.goal_mu = torch.tensor([0, 0, 0, 0]).to(self.device)
        self.goal_var = torch.tensor([1, 1, 1, 1]).to(self.device)

        # Generative state prior: P_theta
        self.prior_transition_net_theta = MVGaussianModel(
            self.latent_state_dim + 1,   # inputs: s_{t-1}, a_{t-1}
            self.latent_state_dim,       # outputs: mu(s_t), sigma(s_t)
            self.n_hidden_gen_trans,     # hidden layer
            lr=self.lr_gen_trans,        # learning rate
            device=self.device,
            name='prior_theta'
        )

        # Variational state posterior: Q_phi
        self.posterior_transition_net_phi = MVGaussianModel(
            2 * self.latent_state_dim + 1,  # inputs: s_{t-1}, a_{t-1}, o_t
            self.latent_state_dim,          # outputs: mu(s_t), sigma(s_t)
            self.n_hidden_var_trans,        # hidden layer
            lr=self.lr_var_trans,           # learning rate
            device=self.device,
            name='posterior_phi'
        )

        # Generative observation likelihood: P_xi
        self.generative_observation_net_xi = MVGaussianModel(
            self.latent_state_dim,   # input: s_t (reparameterised state sample)
            self.latent_state_dim,   # outputs: mu(o_t), sigma(o_t)
            self.n_hidden_gen_obs,   # hidden layer
            lr=self.lr_gen_obs,      # learning rate
            device=self.device,
            name='obs_xi'
        )

        # Variational policy prior: Q_nu
        self.policy_net_nu = Model(
            2 * self.latent_state_dim,  # inputs: mu(s_t), sigma(s_t)
            self.n_actions,             # output: categorical action probabilities for all actions in the action space
            self.n_hidden_pol,          # hidden layer
            lr=self.lr_pol,             # learning rate
            softmax=True,
            device=self.device
        )

        # EFE bootstrap-estimate network: f_psi
        self.value_net_psi = Model(
            2 * self.latent_state_dim,  # inputs: mu(s_t), sigma(s_t)
            self.n_actions,             # output: estimated EFE for all actions in the action space
            self.n_hidden_val,          # hidden layer
            lr=self.lr_val,             # learning rate
            device=self.device
        )

        # TARGET EFE bootstrap-estimate network: f_psi
        self.target_net = Model(
            2 * self.latent_state_dim,  # inputs: mu(s_t), sigma(s_t)
            self.n_actions,             # output: estimated EFE for all actions in the action space
            self.n_hidden_val,          # hidden layer
            lr=self.lr_val,             # learning rate
            device=self.device
        )
        self.target_net.load_state_dict(self.value_net_psi.state_dict())

        if self.load_network:  # If true: load the networks from the given paths
            self.prior_transition_net_theta.load_state_dict(torch.load(self.network_load_path.format("trans")))
            self.prior_transition_net_theta.eval()
            self.generative_observation_net_xi.load_state_dict(torch.load(self.network_load_path.format("obs")))
            self.generative_observation_net_xi.eval()
            self.posterior_transition_net_phi.load_state_dict(torch.load(self.network_load_path.format("var_trans")))
            self.posterior_transition_net_phi.eval()
            self.policy_net_nu.load_state_dict(torch.load(self.network_load_path.format("pol"), map_location=self.device))
            self.policy_net_nu.eval()
            self.value_net_psi.load_state_dict(torch.load(self.network_load_path.format("val"), map_location=self.device))
            self.value_net_psi.eval()

        # Initialize the replay memory
        self.memory = ReplayMemory(self.memory_capacity, self.obs_shape, device=self.device)

        # When sampling from memory at index i, obs_indices indicates that we want the observations
        # at indices i - obs_indices; the other index lists work the same way.
        # self.obs_indices = [(self.n_screens+1)-i for i in range(self.n_screens+2)]
        self.obs_indices = [3, 2, 1, 0]           # fourth, third and second most recent, and most recent observation
        self.hidden_state_indices = [3, 2, 1, 0]  # fourth, third and second most recent, and most recent inferred hidden state
        # self.action_indices = [2, 1]
        self.action_indices = [3, 2, 1, 0]
        self.reward_indices = [1]
        # self.done_indices = [0]
        self.terminated_indices = [0]  # I assume it's ok that these are the same?
        self.truncated_indices = [0]   # I assume it's ok that these are the same?

        # PRIOR VERSION:
        # self.max_n_indices = max(max(self.obs_indices, self.action_indices, self.reward_indices,
        #                              self.terminated_indices, self.truncated_indices, self.hidden_state_indices)) + 1
        self.max_n_indices = max(max(self.obs_indices, self.action_indices, self.reward_indices,
                                     self.terminated_indices, self.truncated_indices, self.hidden_state_indices)) + 2
Finally, here is the definition of my MVGaussianModel class, since the troublesome self.posterior_transition_net_phi is an instance of this class:
class MVGaussianModel(nn.Module):
    def __init__(self, n_inputs, n_outputs, n_hidden, lr=1e-4, device='cpu', name=None):
        super(MVGaussianModel, self).__init__()
        self.name = name
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.n_hidden = n_hidden

        self.fc1 = nn.Linear(n_inputs, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.mean_fc = nn.Linear(n_hidden, n_outputs)
        self.log_var = nn.Linear(n_hidden, n_outputs)

        self.optimizer = optim.Adam(self.parameters(), lr)
        self.device = device
        self.to(self.device)

    def forward(self, x):
        x_1 = torch.relu(self.fc1(x))
        x_2 = torch.relu(self.fc2(x_1))
        mean = self.mean_fc(x_2)
        log_var = self.log_var(x_2)
        return mean, log_var

    def rsample(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(mu)
        return mu + (eps * std)
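Since the inputs appear to be finite, I suspect the parameters of self.posterior_transition_net_phi themselves must be becoming NaN at some earlier optimiser step (a Linear/ReLU stack with finite weights cannot map finite inputs to NaN). For reference, this is the kind of instrumentation I could add to narrow that down; it is a sketch of possibilities (parameter finiteness check, gradient clipping, clamped reparameterisation), not something I have already established as the fix:

import torch

def params_are_finite(model: torch.nn.Module) -> bool:
    # True only if every parameter of the model is free of NaN / inf
    return all(torch.isfinite(p).all() for p in model.parameters())

# Illustrative guards around an update step (hypothetical; not my current code):
# loss.backward()
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # tame exploding gradients
# model.optimizer.step()
# assert params_are_finite(model), "parameters became non-finite after this update"

# A clamped variant of rsample, keeping exp(0.5 * log_var) from overflowing:
def rsample_clamped(mu, log_var, min_logvar=-10.0, max_logvar=10.0):
    std = torch.exp(0.5 * torch.clamp(log_var, min_logvar, max_logvar))
    return mu + torch.randn_like(mu) * std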
I would be extremely grateful for any assistance, queries or suggestions. As I say, I am very happy to elaborate on any points at all. In an effort to be as comprehensive as possible, I also enclose screenshots from a small LaTeX document of mine, which details the structure of all of the above networks; perhaps that might be useful.
Many thanks!