[PettingZoo] Trouble running multiple MARL environments in parallel

Hi all,

I'm trying to use multi-agent SAC to train a policy for MultiWalker-v9.

Since I am new to both MARL and TorchRL, I loosely followed the MADDPG/MAPPO tutorials as well as the SOTA implementations, swapping out parts as necessary.

NOTE: This is not a question about PettingZoo’s Parallel API, which works fine for me. I am trying to run multiple copies of the PettingZoo environment to speed up data collection.

This is my code:

from copy import deepcopy
import tqdm
import numpy as np
from gymnasium.spaces import Box

import logging

import torch
from torch import nn

from torchrl.data.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage 
from torchrl.envs import (
    check_env_specs,
    PettingZooEnv, 
    ParallelEnv,
    GymEnv
)
from torchrl.modules import TanhNormal, AdditiveGaussianWrapper, ProbabilisticActor
from torchrl.modules.models import (
    MLP
)
from torchrl.modules.models.multiagent import (
    MultiAgentMLP,
    MultiAgentNetBase
)
from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
from torchrl.objectives import SACLoss, SoftUpdate

from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.envs import EnvCreator, TransformedEnv

import multiprocessing as mp

class SMACCNet(MultiAgentNetBase): 
    '''
    https://pytorch.org/rl/main/_modules/torchrl/modules/models/multiagent.html 
    This is just a more limited version of MultiAgentMLP. 
    '''
    def __init__(self, 
                n_agent_inputs: int | None,
                n_agent_outputs: int,
                n_agents: int,
                centralised: bool,
                share_params: bool,
                device = None,
                activation_class = nn.Tanh,
                **kwargs):

        self.n_agents = n_agents
        self.n_agent_inputs = n_agent_inputs
        self.n_agent_outputs = n_agent_outputs
        self.share_params = share_params
        self.centralised = centralised
        self.activation_class = activation_class

        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device = device,
            **kwargs,
        )
    
    def _pre_forward_check(self, inputs):
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        # If the model is centralized, agents have full observability
        if self.centralised:
            inputs = inputs.flatten(-2, -1)
        return inputs

    def _build_single_net(self, *, device, **kwargs):
        n_agent_inputs = self.n_agent_inputs
        if self.centralised and n_agent_inputs is not None:
            n_agent_inputs = self.n_agent_inputs * self.n_agents
        
        # Note to self: This is where you change the model architecture.
        model = nn.Sequential(
            nn.Linear(n_agent_inputs, 400),
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Linear(300, self.n_agent_outputs)
        ).to(device)

        return model

# ripped from StackOverflow
class TqdmLoggingHandler(logging.Handler):
    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        try:
            msg = self.format(record)
            tqdm.tqdm.write(msg)
            self.flush()
        except Exception:
            self.handleError(record)  

# Main Function
if __name__ == "__main__":
    logging.basicConfig(level = logging.INFO)
    logger = logging.getLogger(__name__)
    logger.propagate = False
    logger.addHandler(TqdmLoggingHandler())

    mp.set_start_method("spawn", force = True)
    
    NUM_CRITICS = 2
    EXPLORATION_STEPS = 100 #30000
    MAX_EPISODE_STEPS = 10 #2000
    DEVICE = "cuda"
    REPLAY_BUFFER_SIZE = 1e3 #5e5
    VALUE_GAMMA = 0.99
    BATCH_SIZE = 32 #256
    EPS = 1e-7
    LR = 3e-4
    UPDATE_STEPS_PER_BATCH = 1 #750
    WARMUP_STEPS = 10 #int(1e5)
    TRAIN_TIMESTEPS = 5000 #int(2.5e6)
    EVAL_INTERVAL = 1 #int(5e5 // EXPLORATION_STEPS)

    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    def env_fn(mode, parallel = True):
        base_env = PettingZooEnv(task = "multiwalker_v9", 
                                    parallel = True,
                                    seed = 42,
                                    n_walkers = 3, 
                                    shared_reward = False, 
                                    max_cycles = 1000, 
                                    render_mode = mode, 
                                    device = "cpu"
                                )

        if parallel:
            env = ParallelEnv(num_workers = 4,  # noqa: E731
                                        create_env_fn = EnvCreator(lambda: base_env), 
                                        device = "cpu",
                                        mp_start_method = "spawn",
                                        serial_for_single = True
                                    )
        else:
            env = base_env # noqa: E731

        return lambda: env

    train_env = env_fn(None)()
    if train_env.is_closed:
        train_env.start()
    eval_env = env_fn("rgb_array", parallel = False)()

    check_env_specs(train_env)

    obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
    action_dim = train_env.full_action_spec["walker", "action"].shape[-1]

    # NOTE: When using a ParallelEnv, you must call the attributes you want to access.
    # e.g. train_env.group_map(), not train_env.group_map
    agent_grpings = train_env.group_map()[0]
    num_agents = len(agent_grpings["walker"])
    # num_agents = len(train_env.group_map["walker"])

    # Construct the raw policy network (which acts in a decentralised manner).
    # NOTE: SAC is a probabilistic policy; here we ask that the network outputs the mean AND std
    # of the action space.
    policy_net = nn.Sequential(
                        SMACCNet(n_agent_inputs = obs_dim,
                          n_agent_outputs = 2 * action_dim, 
                          n_agents = num_agents,
                          centralised = False, # i.e. Not a joint policy.
                          share_params = True, # But agents act from the same playbook. This becomes less of a problem with communication.
                          device = DEVICE,
                          activation_class = nn.Tanh, 
                        ),
                        NormalParamExtractor(),
                    )
    
    # Construct the raw critic networks (which ARE in fact centralized).
    # As these are Q-functions, we need to add the action space too.
    # Apparently TorchRL already takes care of the multiple critic networks for you.
    critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
                          n_agent_outputs = 1,
                          n_agents = num_agents,
                          centralised = True, # i.e. A joint policy.
                          share_params = True,
                          device = DEVICE,
                          activation_class = nn.Tanh, 
                        )

    # TorchRL's pipelines are VERY HEAVILY focused on performing transformations on dictionaries.
    policy_net_td_module = TensorDictModule(module = policy_net,
                                            in_keys = [("walker", "observation")],
                                            # NOTE: These outputs must match with the parameter names of the 
                                            # distribution you are using!
                                            out_keys = [("walker", "loc"), ("walker", "scale")]
                                        )

    # I must confess I just copied this from the MADDPG setup.
    # However, if we are interested in using the same module code for the policy and critic networks,
    # this is probably a better solution than subclassing the MLP and saying you accept 2 inputs. 
    obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
                                        in_keys = [("walker", "observation"), ("walker", "action")],
                                        out_keys = [("walker", "obs_act")]
                                    )
    critic_net_td_module = TensorDictModule(module = critic_net,
                                            in_keys = [("walker", "obs_act")],
                                            out_keys = [("walker", "state_action_value")]
                                        )

    # Attach our raw policy network to a probabilistic actor
    policy_actor = ProbabilisticActor(
        module = policy_net_td_module,
        spec = train_env.full_action_spec["walker", "action"],
        in_keys = [("walker", "loc"), ("walker", "scale")],
        out_keys = [("walker", "action")],
        distribution_class = TanhNormal,
        distribution_kwargs = {
            "min": train_env.full_action_spec["walker", "action"].space.low,
            "max": train_env.full_action_spec["walker", "action"].space.high,
        },
        return_log_prob = False,
    )
    # Hopefully there are no lazy layers, but just in case, wake them up.
    fake_td = train_env.fake_tensordict().to(DEVICE)
    policy_actor(fake_td)

    # Add some extra noise to SAC.
    # In case you ask "Why are you adding noise when there is already noise sampled by the ProbabilisticActor",
    # Doing so lets us ensure a non-negligible *minimum* action noise for our model, which can help it explore better.
    dora = AdditiveGaussianWrapper(
        policy = policy_actor,
        action_key = ("walker", "action"),
        sigma_init = 0.3,
        sigma_end = 0.1,
        annealing_num_steps = 20000
    )

    critic_actor = TensorDictSequential(
                            obs_act_module, critic_net_td_module
                        ) 
                    
    collector = SyncDataCollector( 
                    train_env,
                    policy = dora, # the explora
                    frames_per_batch = EXPLORATION_STEPS,
                    max_frames_per_traj = MAX_EPISODE_STEPS,
                    total_frames = TRAIN_TIMESTEPS,
                    device = "cpu",
                    reset_at_each_iter = False,
                )

    replay_buffer = TensorDictReplayBuffer(
        storage = LazyMemmapStorage(
            REPLAY_BUFFER_SIZE, device = "cpu",
        ),  # We will store up to memory_size multi-agent transitions
        sampler = RandomSampler(),
        batch_size = BATCH_SIZE,  # We will sample batches of this size
    )

    sac_loss = SACLoss(policy_actor, 
                        qvalue_network = critic_actor, 
                        num_qvalue_nets = 2,
                        loss_function = "l2",
                        delay_qvalue = True,
                        alpha_init = 0.1
                        )

    # Apparently TorchRL maintains its own set of default keys for the loss calculation.
    # To avoid key mismatches, we must change them.
    sac_loss.set_keys(
        action = ("walker", "action"),
        state_action_value = ("walker", "state_action_value"),
        reward = ("walker", "reward"),
        done = ("walker", "done"),
        terminated = ("walker", "terminated"),
    )
    
    # To be sure, we ARE using the Q-value-only formulation. Unfortunately, I don't know if this is necessary or not.
    # Let's leave it in for now.
    # My hunch is yes...? Because the gamma value is not yet supplied but nevertheless required for Q-learning updates.
    # https://pytorch.org/rl/stable/reference/generated/torchrl.objectives.SACLoss.html 
    # TODO: FIND OUT WHAT IS GOING ON HERE
    sac_loss.make_value_estimator(gamma = VALUE_GAMMA)

    polyak_updater = SoftUpdate(sac_loss, tau = 0.005) 

    critic_params = list(sac_loss.qvalue_network_params.flatten_keys().values())
    actor_params = list(sac_loss.actor_network_params.flatten_keys().values())

    optimizer_actor = torch.optim.Adam(
        actor_params,
        lr = LR,
        # weight_decay = 1e-4,
        eps = EPS,
    )
    optimizer_critic = torch.optim.Adam(
        critic_params,
        lr = LR,
        eps = EPS,
    )
    optimizer_alpha = torch.optim.Adam(
        [sac_loss.log_alpha],
        lr = 3e-4,
    )

    num_frames = 0
    pbar = tqdm.tqdm(total = TRAIN_TIMESTEPS)
    total_frames = 0
    ep_info = {'rewards': [], 'lengths': []}

    # Now everything is in place, write the training loop...
    for i, tensordict in enumerate(collector):

        collector.update_policy_weights_()

        pbar.update(tensordict.numel())

        tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
        # Add to replay buffer
        replay_buffer.extend(tensordict.cpu())
        total_frames += current_frames

        # Optimization steps
        if total_frames >= WARMUP_STEPS:
            losses = TensorDict({}, batch_size=[UPDATE_STEPS_PER_BATCH])
            for j in range(UPDATE_STEPS_PER_BATCH):  # separate index so the outer loop's `i` is not shadowed
                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample()
                if sampled_tensordict.device != DEVICE:
                    sampled_tensordict = sampled_tensordict.to(
                        DEVICE, non_blocking=True
                    )
                else:
                    sampled_tensordict = sampled_tensordict.clone()

                try:
                    # Compute loss
                    loss_td = sac_loss(sampled_tensordict)
                except KeyError:
                    raise Exception(f"Check {sampled_tensordict}\n{obs_act_module(sampled_tensordict)['walker', 'obs_act']}")

                actor_loss = loss_td["loss_actor"]
                q_loss = loss_td["loss_qvalue"]
                alpha_loss = loss_td["loss_alpha"]

                # Update actor
                optimizer_actor.zero_grad()
                actor_loss.backward()
                optimizer_actor.step()

                # Update critic
                optimizer_critic.zero_grad()
                q_loss.backward()
                optimizer_critic.step()

                # Update alpha
                optimizer_alpha.zero_grad()
                alpha_loss.backward()
                optimizer_alpha.step()

                losses[j] = loss_td.select(
                    "loss_actor", "loss_qvalue", "loss_alpha"
                ).detach()

                # Update qnet_target params
                polyak_updater.step()

        if not ((i + 1) % EVAL_INTERVAL):
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
                eval_rollout = eval_env.rollout(
                    MAX_EPISODE_STEPS,
                    policy_actor,
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )

                eval_reward = eval_rollout["next", "walker", "reward"].sum(-2).mean().item()
                logger.info(f"Mean Reward: {eval_reward}")
                # logger.info("SUCCESS")
            ep_reward_list = []
    collector.shutdown()
    train_env.close()

Running this code without ParallelEnv proceeds without errors. With ParallelEnv, however, I get the following traceback:

  File "/home/n00bcak/Desktop/<path_to_file>/smacc.py", line 433, in <module>
    for i, tensordict in enumerate(collector):
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 888, in iterator
    tensordict_out = self.rollout()
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/_utils.py", line 480, in unpack_rref_and_invoke_function
    return func(self, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1007, in rollout
    env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/common.py", line 2750, in step_and_maybe_reset
    tensordict_ = self.maybe_reset(tensordict_)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/common.py", line 2795, in maybe_reset
    tensordict = self.reset(tensordict)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/common.py", line 2120, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 809, in _reset
    tensordict_reset = self.base_env._reset(tensordict, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/batched_envs.py", line 60, in decorated_fun
    return fun(self, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/batched_envs.py", line 1518, in _reset
    self.shared_tensordicts[i].apply_(
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/tensordict/base.py", line 4289, in apply_
    return self.apply(fn, *others, inplace=True, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/tensordict/base.py", line 4387, in apply
    result = self._apply_nest(
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/tensordict/_td.py", line 988, in _apply_nest
    item_trsf = item._apply_nest(
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/tensordict/_td.py", line 1013, in _apply_nest
    item_trsf = fn(item, *_others)
  File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/batched_envs.py", line 1515, in tentative_update
    val.copy_(other, non_blocking=self.non_blocking)
RuntimeError: output with shape [3, 1] doesn't match the broadcast shape [3, 3]

I have failed to pinpoint the source of the error after some time, although it appears to be caused by a malformed tensordict that emerges from somewhere within TorchRL's source code.

Any insights on this issue are greatly appreciated.

P.S. I have alternatively tried to use MultiSyncDataCollector:

    collector = MultiSyncDataCollector(
                    [EnvCreator(lambda: train_env)] * 8, # train_env now ONLY uses PettingZoo's Parallel API
                    policy = dora.to('cpu'), # the explora
                    frames_per_batch = EXPLORATION_STEPS,
                    max_frames_per_traj = MAX_EPISODE_STEPS,
                    total_frames = TRAIN_TIMESTEPS,
                    device = "cpu",
                    policy_device = "cpu",
                    storing_device = "cpu",
                    env_device = "cpu",
                    reset_at_each_iter = False
                )

This results in the following error:

Traceback (most recent call last):
  File "/home/n00bcak/Desktop/<path_to_script>/smacc.py", line 355, in <module>
    collector = MultiSyncDataCollector(
  File "/home/n00bcak/Desktop/<path_to_venv>/python3.10/site-packages/torchrl/collectors/collectors.py", line 1514, in __init__
    self._run_processes()
  File "/home/n00bcak/Desktop/<path_to_venv>/python3.10/site-packages/torchrl/collectors/collectors.py", line 1672, in _run_processes
    proc.start()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/usr/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/usr/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/usr/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/usr/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/usr/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File "/home/n00bcak/Desktop/<path_to_venv>/python3.10/site-packages/torch/multiprocessing/reductions.py", line 568, in reduce_storage
    fd, size = storage._share_fd_cpu_()
  File "/home/n00bcak/Desktop/<path_to_venv>/python3.10/site-packages/torch/storage.py", line 304, in wrapper
    return fn(self, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/python3.10/site-packages/torch/storage.py", line 374, in _share_fd_cpu_
    return super()._share_fd_cpu_(*args, **kwargs)
RuntimeError: _share_fd_: only available on CPU

Any insights on this bug are appreciated as well!

At first glance, the problem could be stemming from here: the lambda should create the env, not just point to the existing one; otherwise multiple envs won't be created.
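In other words, roughly (a minimal sketch with hypothetical names, not your exact code):

    from torchrl.envs import PettingZooEnv

    # This closure hands the SAME already-built env instance to every worker:
    env = PettingZooEnv(task="multiwalker_v9", parallel=True)
    bad_fn = lambda: env

    # This one builds a FRESH env each time it is called, which is what
    # ParallelEnv / the collectors expect from create_env_fn:
    good_fn = lambda: PettingZooEnv(task="multiwalker_v9", parallel=True)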

Thank you. I can’t believe I didn’t notice that when debugging ^^;

However, changing the env_fn to:

    def env_fn(mode, parallel = True):
        def base_env_fn():
            return PettingZooEnv(task = "multiwalker_v9", 
                                    parallel = True,
                                    seed = 42,
                                    n_walkers = 3, 
                                    shared_reward = False, 
                                    max_cycles = 1000, 
                                    render_mode = mode, 
                                    device = "cpu"
                                )

        if parallel:
            env = lambda: ParallelEnv(num_workers = 4,  # noqa: E731
                                        create_env_fn = EnvCreator(base_env_fn), 
                                        device = "cpu",
                                        mp_start_method = "spawn",
                                        serial_for_single = True
                                    )
        else:
            env = base_env_fn # noqa: E731

        return env

and the data collector to:

    collector = SyncDataCollector( 
                    env_fn(None),
                    policy = dora,
                    frames_per_batch = EXPLORATION_STEPS,
                    max_frames_per_traj = MAX_EPISODE_STEPS,
                    total_frames = TRAIN_TIMESTEPS,
                    device = "cpu",
                    reset_at_each_iter = False,
                )

still results in the same RuntimeError, so I do not believe this was the root cause.

P.S. Calling EnvCreator(env_fn(None)) (i.e. wrapping the ParallelEnv factory in an EnvCreator) results in the following error:

Traceback (most recent call last):
  File "/home/n00bcak/Desktop/drones_go_brr/scripts/smacc.py", line 344, in <module>
    collector = SyncDataCollector( 
  File "/home/n00bcak/Desktop/programming/venvs/thales/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 455, in __init__
    env = create_env_fn(**create_env_kwargs)
  File "/home/n00bcak/Desktop/programming/venvs/thales/lib/python3.10/site-packages/torchrl/envs/env_creator.py", line 142, in __call__
    env.load_state_dict(self._transform_state_dict, strict=False)
  File "/home/n00bcak/Desktop/programming/venvs/thales/lib/python3.10/site-packages/torchrl/envs/batched_envs.py", line 60, in decorated_fun
    return fun(self, *args, **kwargs)
TypeError: ParallelEnv.load_state_dict() got an unexpected keyword argument 'strict'

This seems to be a separate, unexpected bug; not wrapping the factory in EnvCreator avoids the error.

I see.

I have simplified your code down to the core bug.
It seems to be a bug in the resetting of ParallelEnv:

from torchrl.envs import PettingZooEnv, ParallelEnv, SerialEnv
from torchrl.collectors import SyncDataCollector

if __name__ == "__main__":

    def base_env_fn():
        return PettingZooEnv(
            task="multiwalker_v9",
            parallel=True,
            seed=42,
            n_walkers=3,
            shared_reward=False,
            max_cycles=1000,
            render_mode=None,
            device="cpu",
        )

    collector = SyncDataCollector(
        lambda: ParallelEnv(
            num_workers=4,  # noqa: E731
            create_env_fn=base_env_fn,
            device="cpu",
        ),
        policy=None,
        frames_per_batch=100,
        max_frames_per_traj=50,
        total_frames=200,
        device="cpu",
        reset_at_each_iter=False,
    )
    try:
        for i, tensordict in enumerate(collector):
            print("Reached")
    finally:
        collector.shutdown()

I suggest opening an issue in the torchrl repo with this script and a description, and we can find a solution from there.

I believe this will fix it.

Sorry for the late reply; I have opened GitHub issues relating to the bugs mentioned above.

In each I have pinpointed a possible source of the bug from what I can see.

I would agree with @matteobettini on the root cause of the bug:

Stepping through the code shows that this is exactly the block where the _reset key emerges misshapen, as the wrong batch_size is taken as the reference.
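For illustration only, a standalone snippet (plain torch, with shapes assumed from the traceback rather than taken from TorchRL internals) that reproduces the same broadcast failure:

    import torch

    # Assumed shapes: the shared tensordict slot expects a [3, 1] per-agent entry
    # (3 walkers, one flag each), while the reset hands back a [3, 3] tensor
    # built against the wrong batch_size.
    dst = torch.zeros(3, 1, dtype=torch.bool)
    src = torch.zeros(3, 3, dtype=torch.bool)
    dst.copy_(src)  # RuntimeError: output with shape [3, 1] doesn't match the broadcast shape [3, 3]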