EOFError when training A2C

I’m trying to train an A2C model on SpaceInvadersNoFrameskip-v4, but when I try to collect a trajectory I get the error below and cannot figure out what the problem might be.

This is my model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    def __init__(self, num_actions, use_bn=False):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=512)
        self.actor = nn.Linear(in_features=512, out_features=num_actions)
        self.critic = nn.Linear(in_features=512, out_features=1)

        # initialize biases with zeros
        nn.init.constant_(self.fc1.bias, 0)
        nn.init.constant_(self.actor.bias, 0)
        nn.init.constant_(self.critic.bias, 0)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc1(x.view(x.size(0), -1)))
        value = self.critic(x)
        return self.actor(x), value
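As far as I can tell the network itself is dimensionally consistent with the usual 84x84 Atari frames and a 4-frame stack; a quick shape check like this should give the expected sizes:

# shape check of the forward pass (assumes 84x84 frames stacked 4 deep)
net = DQN(num_actions=6)                 # SpaceInvaders has 6 discrete actions
logits, value = net(torch.zeros(2, 4, 84, 84))
print(logits.shape, value.shape)         # torch.Size([2, 6]) torch.Size([2, 1])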

Policy class:

class Policy:
    def __init__(self, model):
        self.model = model

    def act(self, inputs):
        # Implement a policy by calling the model, sampling actions and computing their log probs.
        # Should return a dict containing keys ['actions', 'logits', 'log_probs', 'values'].

        inputs = torch.FloatTensor(inputs)
        print(inputs.shape)

        with torch.no_grad():
          logits, values = self.model(inputs)
          probs = np.array(F.softmax(logits, -1))
          log_probs = np.array(F.log_softmax(logits, -1))
        entropy = -(probs * log_probs).sum(-1).mean()

        actions = np.zeros((logits.shape[0],))
        for i in range(logits.shape[0]):
          actions[i] = np.random.choice(n_actions, p=probs[i])

        log_probs_for_actions = torch.sum(torch.Tensor(log_probs) * F.one_hot(torch.Tensor(actions).to(torch.int64), env.action_space.n), dim=1)

        return dict(actions=actions, logits=logits, log_probs=log_probs_for_actions, values=values, entropy=entropy)

Computing target values:

class ComputeValueTargets:
    def __init__(self, policy, gamma=0.99):
        self.policy = policy
        self.gamma = gamma

    def __call__(self, trajectory):
        """Compute value targets for a given partial trajectory."""

        # This method should modify trajectory inplace by adding
        # an item with key 'value_targets' to it.
        rewards = trajectory.get("rewards")
        resets = trajectory.get("resets")
        qa_values = trajectory.get("values")  # critic estimates from the runner (not used below)
        t = len(rewards)

        value_targets = []
        # need to use policy here to estimate some values with critic
        for i in range(t):
            value_target = 0
            # discounted sum of the rewards remaining in this partial trajectory
            for j in range(t - i):
                value_target += self.gamma ** j * rewards[i + j]
            if resets[i] != 0:
                value = self.policy.act(trajectory['state']['latest_observation'][i])['values']
                value_target += self.gamma ** (t - i) * value
            value_targets.append(value_target)

        trajectory['value_targets'] = value_targets
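To double-check the transform on its own, a tiny hand-made trajectory (resets all zero, so the bootstrap branch that needs the policy is never taken) should give plain discounted n-step returns:

# toy check: no resets, so policy=None is never touched
toy = {"rewards": [1.0, 0.0, 2.0], "resets": [0, 0, 0]}
ComputeValueTargets(policy=None, gamma=0.99)(toy)
print(toy["value_targets"])   # roughly [2.9602, 1.98, 2.0]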

Initializing everything:

model = DQN(n_actions)
policy = Policy(model)
runner = EnvRunner(
    env=env,
    policy=policy,
    nsteps=5,
    transforms=[
        ComputeValueTargets(policy),
        MergeTimeBatch(),
    ],
)
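(env and n_actions come from setup code not shown here; env is the batched SpaceInvaders environment built on top of env_batch.py.) At this point policy.act on a batched observation should return the keys and shapes the runner appends at each step, assuming the usual (nenvs, 4, 84, 84) stacked-frame observations:

# quick look at what the runner receives from the policy
obs = env.reset()[0]
out = policy.act(obs)
for key in ("actions", "logits", "log_probs", "values"):
    print(key, np.shape(out[key]))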

I have this EnvRunner class that should get me the trajectory dict:

from collections import defaultdict

import numpy as np

class EnvRunner:
    """Reinforcement learning runner in an environment with given policy"""

    def __init__(self, env, policy, nsteps, transforms=None, step_var=None):
        self.env = env
        self.policy = policy
        self.nsteps = nsteps
        self.transforms = transforms or []
        self.step_var = step_var if step_var is not None else 0
        self.state = {"latest_observation": self.env.reset()[0]}

    @property
    def nenvs(self):
        """Returns number of batched envs or `None` if env is not batched"""
        return getattr(self.env.unwrapped, "nenvs", None)

    def reset(self, **kwargs):
        """Resets env and runner states."""
        self.state["latest_observation"] = self.env.reset(**kwargs)[0]
        self.policy.reset()

    def add_summary(self, name, val):
        """Writes logs"""
        add_summary = self.env.get_wrapper_attr("add_summary")
        add_summary(name, val)

    def get_next(self):
        """Runs the agent in the environment."""
        trajectory = defaultdict(list, {"actions": []})
        observations = []
        rewards = []
        resets = []
        self.state["env_steps"] = self.nsteps

        for i in range(self.nsteps):
            observations.append(self.state["latest_observation"])
            act = self.policy.act(self.state["latest_observation"])
            if "actions" not in act:
                raise ValueError(
                    "result of policy.act must contain 'actions' "
                    f"but has keys {list(act.keys())}"
                )
            for key, val in act.items():
                trajectory[key].append(val)

            obs, rew, terminated, truncated, _ = self.env.step(
                trajectory["actions"][-1]
            )
            self.state["latest_observation"] = obs
            rewards.append(rew)
            reset = np.logical_or(terminated, truncated)
            resets.append(reset)
            self.step_var += self.nenvs or 1

            # Only reset if the env is not batched. Batched envs should
            # auto-reset.
            if not self.nenvs and np.all(reset):
                self.state["env_steps"] = i + 1
                self.state["latest_observation"] = self.env.reset()[0]

        trajectory.update(observations=observations, rewards=rewards, resets=resets)
        trajectory["state"] = self.state

        for transform in self.transforms:
            transform(trajectory)
        return trajectory

And finally the class that puts everything together:

class A2C:
    def __init__(self,
                 policy,
                 optimizer,
                 value_loss_coef=0.25,
                 entropy_coef=0.01,
                 max_grad_norm=0.5):
        self.policy = policy
        self.optimizer = optimizer
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm

    def policy_loss(self, trajectory):
        # You will need to compute advantages here.
        states = trajectory['observations']
        # d = ['actions', 'logits', 'log_probs', 'values']
        d = self.policy.act(states)
        true_values = trajectory['value_targets']
        advantages = true_values - d['values']
        actor_loss = -(d['log_probs'] * advantages.detach()).mean()
        return actor_loss

    def value_loss(self, trajectory):
        states = trajectory['observations']
        d = self.policy.act(states)
        true_values = trajectory['value_targets']
        advantages = true_values - d['values']
        critic_loss = advantages.pow(2).mean()
        return critic_loss

    def loss(self, trajectory):
        d = self.policy.act(trajectory['observations'])
        entropy = d['entropy']
        total_loss = (self.value_loss_coef * self.value_loss(trajectory)) + self.policy_loss(trajectory) - (self.entropy_coef * entropy)
        return total_loss

    def step(self, trajectory):
        self.optimizer.zero_grad()
        self.loss(trajectory).backward()
        torch.nn.utils.clip_grad_norm_(self.policy.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
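For completeness, the training loop I intend to run once get_next works is roughly this (the optimizer and its hyperparameters are just placeholders):

optimizer = torch.optim.RMSprop(model.parameters(), lr=7e-4, alpha=0.99, eps=1e-5)
a2c = A2C(policy, optimizer)

for step in range(10_000):
    trajectory = runner.get_next()
    a2c.step(trajectory)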

The error happens when I run:

trajectory = runner.get_next()

This is the error message:

EOFError                                  Traceback (most recent call last)
<ipython-input-23-e42d189fa27f> in <cell line: 1>()
----> 1 trajectory = runner.get_next()

7 frames
<ipython-input-2-104fc603a9cd> in get_next(self)
     48                 trajectory[key].append(val)
     49 
---> 50             obs, rew, terminated, truncated, _ = self.env.step(
     51                 trajectory["actions"][-1]
     52             )

/usr/local/lib/python3.10/dist-packages/gymnasium/core.py in step(self, action)
    553     ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
    554         """Modifies the :attr:`env` :meth:`step` reward using :meth:`self.reward`."""
--> 555         observation, reward, terminated, truncated, info = self.env.step(action)
    556         return observation, self.reward(reward), terminated, truncated, info
    557 

<ipython-input-3-29cc79dc072c> in step(self, action)
    282 
    283     def step(self, action):
--> 284         obs, rew, terminated, truncated, info = self.env.step(action)
    285         self.rewards += rew
    286         self.episode_lengths[~self.had_ended_episodes] += 1

/content/env_batch.py in step(self, actions)
    212         for conn, a in zip(self._parent_connections, actions):
    213             conn.send(("step", a))
--> 214         results = [conn.recv() for conn in self._parent_connections]
    215         obs, rews, terminated, truncated, infos = zip(*results)
    216         return (

/content/env_batch.py in <listcomp>(.0)
    212         for conn, a in zip(self._parent_connections, actions):
    213             conn.send(("step", a))
--> 214         results = [conn.recv() for conn in self._parent_connections]
    215         obs, rews, terminated, truncated, infos = zip(*results)
    216         return (

/usr/lib/python3.10/multiprocessing/connection.py in recv(self)
    248         self._check_closed()
    249         self._check_readable()
--> 250         buf = self._recv_bytes()
    251         return _ForkingPickler.loads(buf.getbuffer())
    252 

/usr/lib/python3.10/multiprocessing/connection.py in _recv_bytes(self, maxsize)
    412 
    413     def _recv_bytes(self, maxsize=None):
--> 414         buf = self._recv(4)
    415         size, = struct.unpack("!i", buf.getvalue())
    416         if size == -1:

/usr/lib/python3.10/multiprocessing/connection.py in _recv(self, size, read)
    381             if n == 0:
    382                 if remaining == size:
--> 383                     raise EOFError
    384                 else:
    385                     raise OSError("got end of file during message")

EOFError: 
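If I read the traceback correctly, conn.recv() raising EOFError means the worker process on the other end of the pipe has already exited, so the real exception probably happened inside one of the subprocess workers. Two checks I could run to narrow it down (a sketch, using the env and policy created above):

# 1) step the batched env by hand with plain integer actions, bypassing the
#    policy entirely, to see whether the workers themselves stay alive
obs = env.reset()[0]
int_actions = np.random.randint(0, env.action_space.n, size=len(obs))
env.step(int_actions)

# 2) compare with what Policy.act actually sends to env.step
#    (note that np.zeros makes 'actions' a float64 array)
policy_actions = policy.act(obs)["actions"]
print(type(policy_actions), policy_actions.dtype, policy_actions.shape)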

Could somebody please give a hint as to why this might happen?

Thanks in advance.