I’m trying to train an A2C model on SpaceInvadersNoFrameskip-v4, but when I try to collect a trajectory I get an error and cannot figure out what the problem might be.
This is my model:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    def __init__(self, num_actions, use_bn=False):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=512)
        self.actor = nn.Linear(in_features=512, out_features=num_actions)
        self.critic = nn.Linear(in_features=512, out_features=1)
        # initialize biases with zeros
        nn.init.constant_(self.fc1.bias, 0)
        nn.init.constant_(self.actor.bias, 0)
        nn.init.constant_(self.critic.bias, 0)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc1(x.view(x.size(0), -1)))
        value = self.critic(x)
        return self.actor(x), value
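For context, the 64 * 7 * 7 flatten size assumes the usual Atari preprocessing (4 stacked 84x84 grayscale frames). A quick shape check I use to convince myself of that (just a sketch; the 6 actions are a placeholder, SpaceInvaders has 6 actions anyway):

# Sketch: verify that a 4x84x84 input produces a 64x7x7 feature map,
# which is what fc1's in_features (64 * 7 * 7) assumes.
net = DQN(num_actions=6)          # placeholder action count
dummy = torch.zeros(1, 4, 84, 84)
logits, value = net(dummy)
print(logits.shape, value.shape)  # expected: torch.Size([1, 6]) torch.Size([1, 1])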
Policy class:
class Policy:
    def __init__(self, model):
        self.model = model

    def act(self, inputs):
        # Implement a policy by calling the model, sampling actions and computing their log probs.
        # Should return a dict containing keys ['actions', 'logits', 'log_probs', 'values'].
        inputs = torch.FloatTensor(inputs)
        print(inputs.shape)
        with torch.no_grad():
            logits, values = self.model(inputs)
        probs = np.array(F.softmax(logits, -1))
        log_probs = np.array(F.log_softmax(logits, -1))
        entropy = -(probs * log_probs).sum(-1).mean()
        actions = np.zeros((logits.shape[0],))
        for i in range(logits.shape[0]):
            actions[i] = np.random.choice(n_actions, p=probs[i])
        log_probs_for_actions = torch.sum(
            torch.Tensor(log_probs) * F.one_hot(torch.Tensor(actions).to(torch.int64), env.action_space.n),
            dim=1,
        )
        return dict(actions=actions, logits=logits, log_probs=log_probs_for_actions, values=values, entropy=entropy)
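For what it's worth, my understanding is that the sampling step above could also be written with torch.distributions.Categorical. This is only a sketch of the behaviour I'm aiming for in act(), not the code I'm actually running:

import torch
from torch.distributions import Categorical

# Sketch: sample actions and their log-probs from logits in one go.
# `logits` has shape (batch, num_actions), as returned by the model.
def sample_actions(logits):
    dist = Categorical(logits=logits)
    actions = dist.sample()             # shape (batch,)
    log_probs = dist.log_prob(actions)  # shape (batch,)
    entropy = dist.entropy().mean()     # scalar
    return actions, log_probs, entropy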
Computing target values:
class ComputeValueTargets:
    def __init__(self, policy, gamma=0.99):
        self.policy = policy
        self.gamma = gamma

    def __call__(self, trajectory):
        """Compute value targets for a given partial trajectory."""
        # This method should modify trajectory inplace by adding
        # an item with key 'value_targets' to it.
        value_targets = []
        rewards = trajectory.get("rewards")
        resets = trajectory.get("resets")
        qa_values = trajectory.get("values")
        t = len(rewards)
        # need to use policy here to estimate some values with critic
        for i in range(t):
            value_target = 0
            for j in range(t - i):
                value_target += self.gamma ** j * rewards[i + j]
            if resets[i] != 0:
                value = self.policy.act(trajectory['state']['latest_observation'][i])['values']
                value_target += self.gamma ** t * value
            value_targets.append(value_target)
        trajectory['value_targets'] = value_targets
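If it helps to see the intent: my understanding is that n-step value targets are usually computed by walking the rollout backwards and bootstrapping with the critic's value of the observation after the last step. A sketch under those assumptions (not a drop-in replacement for the class above):

import numpy as np

# Sketch: n-step value targets computed backwards over a partial trajectory.
# `rewards` and `resets` are lists of length nsteps; `last_value` is the
# critic's estimate for the observation following the last step.
def n_step_targets(rewards, resets, last_value, gamma=0.99):
    targets = []
    running = last_value
    for rew, reset in zip(reversed(rewards), reversed(resets)):
        running = np.where(reset, 0.0, running)  # no bootstrap across episode ends
        running = rew + gamma * running
        targets.append(running)
    return targets[::-1]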
Initializing everything:
model = DQN(n_actions)
policy = Policy(model)
runner = EnvRunner(
    env=env,
    policy=policy,
    nsteps=5,
    transforms=[
        ComputeValueTargets(policy),
        MergeTimeBatch(),
    ],
)
I have this EnvRunner class that should produce the trajectory dict:
from collections import defaultdict
import numpy as np

class EnvRunner:
    """Reinforcement learning runner in an environment with given policy"""

    def __init__(self, env, policy, nsteps, transforms=None, step_var=None):
        self.env = env
        self.policy = policy
        self.nsteps = nsteps
        self.transforms = transforms or []
        self.step_var = step_var if step_var is not None else 0
        self.state = {"latest_observation": self.env.reset()[0]}

    @property
    def nenvs(self):
        """Returns number of batched envs or `None` if env is not batched"""
        return getattr(self.env.unwrapped, "nenvs", None)

    def reset(self, **kwargs):
        """Resets env and runner states."""
        self.state["latest_observation"] = self.env.reset(**kwargs)[0]
        self.policy.reset()

    def add_summary(self, name, val):
        """Writes logs"""
        add_summary = self.env.get_wrapper_attr("add_summary")
        add_summary(name, val)

    def get_next(self):
        """Runs the agent in the environment."""
        trajectory = defaultdict(list, {"actions": []})
        observations = []
        rewards = []
        resets = []
        self.state["env_steps"] = self.nsteps

        for i in range(self.nsteps):
            observations.append(self.state["latest_observation"])
            act = self.policy.act(self.state["latest_observation"])
            if "actions" not in act:
                raise ValueError(
                    "result of policy.act must contain 'actions' "
                    f"but has keys {list(act.keys())}"
                )
            for key, val in act.items():
                trajectory[key].append(val)

            obs, rew, terminated, truncated, _ = self.env.step(
                trajectory["actions"][-1]
            )
            self.state["latest_observation"] = obs
            rewards.append(rew)
            reset = np.logical_or(terminated, truncated)
            resets.append(reset)
            self.step_var += self.nenvs or 1

            # Only reset if the env is not batched. Batched envs should
            # auto-reset.
            if not self.nenvs and np.all(reset):
                self.state["env_steps"] = i + 1
                self.state["latest_observation"] = self.env.reset()[0]

        trajectory.update(observations=observations, rewards=rewards, resets=resets)
        trajectory["state"] = self.state

        for transform in self.transforms:
            transform(trajectory)
        return trajectory
And finally the class that puts it all together:
class A2C:
    def __init__(self,
                 policy,
                 optimizer,
                 value_loss_coef=0.25,
                 entropy_coef=0.01,
                 max_grad_norm=0.5):
        self.policy = policy
        self.optimizer = optimizer
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
    def policy_loss(self, trajectory):
        # You will need to compute advantages here.
        states = trajectory['observations']
        # d = ['actions', 'logits', 'log_probs', 'values']
        d = self.policy.act(states)
        true_values = trajectory['value_targets']
        advantages = true_values - d['values']
        actor_loss = -(d['log_probs'] * advantages.detach()).mean()
        return actor_loss

    def value_loss(self, trajectory):
        states = trajectory['observations']
        d = self.policy.act(states)
        true_values = trajectory['value_targets']
        advantages = true_values - d['values']
        critic_loss = advantages.pow(2).mean()
        return critic_loss

    def loss(self, trajectory):
        d = self.policy.act(trajectory['observations'])
        entropy = d['entropy']
        total_loss = (self.value_loss_coef * self.value_loss(trajectory)) + self.policy_loss(trajectory) - (self.entropy_coef * entropy)
        return total_loss

    def step(self, trajectory):
        self.optimizer.zero_grad()
        self.loss(trajectory).backward()
        torch.nn.utils.clip_grad_norm_(self.policy.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
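For completeness, this is roughly how I intend to wire everything together once the runner works (a sketch of my plan, not tested code; the Adam learning rate is just a placeholder):

# Sketch of the intended training loop, assuming the classes above work.
optimizer = torch.optim.Adam(model.parameters(), lr=7e-4)  # lr is a placeholder
a2c = A2C(policy, optimizer)

for _ in range(10):                    # a few updates just to smoke-test
    trajectory = runner.get_next()
    a2c.step(trajectory)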
The error happens when I run:
trajectory = runner.get_next()
This is the error message:
EOFError Traceback (most recent call last)
<ipython-input-23-e42d189fa27f> in <cell line: 1>()
----> 1 trajectory = runner.get_next()
7 frames
<ipython-input-2-104fc603a9cd> in get_next(self)
48 trajectory[key].append(val)
49
---> 50 obs, rew, terminated, truncated, _ = self.env.step(
51 trajectory["actions"][-1]
52 )
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py in step(self, action)
553 ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
554 """Modifies the :attr:`env` :meth:`step` reward using :meth:`self.reward`."""
--> 555 observation, reward, terminated, truncated, info = self.env.step(action)
556 return observation, self.reward(reward), terminated, truncated, info
557
<ipython-input-3-29cc79dc072c> in step(self, action)
282
283 def step(self, action):
--> 284 obs, rew, terminated, truncated, info = self.env.step(action)
285 self.rewards += rew
286 self.episode_lengths[~self.had_ended_episodes] += 1
/content/env_batch.py in step(self, actions)
212 for conn, a in zip(self._parent_connections, actions):
213 conn.send(("step", a))
--> 214 results = [conn.recv() for conn in self._parent_connections]
215 obs, rews, terminated, truncated, infos = zip(*results)
216 return (
/content/env_batch.py in <listcomp>(.0)
212 for conn, a in zip(self._parent_connections, actions):
213 conn.send(("step", a))
--> 214 results = [conn.recv() for conn in self._parent_connections]
215 obs, rews, terminated, truncated, infos = zip(*results)
216 return (
/usr/lib/python3.10/multiprocessing/connection.py in recv(self)
248 self._check_closed()
249 self._check_readable()
--> 250 buf = self._recv_bytes()
251 return _ForkingPickler.loads(buf.getbuffer())
252
/usr/lib/python3.10/multiprocessing/connection.py in _recv_bytes(self, maxsize)
412
413 def _recv_bytes(self, maxsize=None):
--> 414 buf = self._recv(4)
415 size, = struct.unpack("!i", buf.getvalue())
416 if size == -1:
/usr/lib/python3.10/multiprocessing/connection.py in _recv(self, size, read)
381 if n == 0:
382 if remaining == size:
--> 383 raise EOFError
384 else:
385 raise OSError("got end of file during message")
EOFError:
Could somebody please give a hint as to why this might happen?
Thanks in advance.