Hi all,
I’m trying to use Multi-Agent SAC to train a policy on MultiWalker-v9.
Since I am new to both MARL and TorchRL, I loosely followed the MADDPG/MAPPO tutorials as well as the SOTA implementation and swapped out parts as necessary.
NOTE: This is not a question about PettingZoo’s Parallel API, which works fine for me. I am trying to run multiple copies of the PettingZoo environment to speed up data collection.
This is my code:
from copy import deepcopy
import tqdm
import numpy as np
from gymnasium.spaces import Box
import logging
import torch
from torch import nn
from torchrl.data.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.envs import (
    check_env_specs,
    PettingZooEnv,
    ParallelEnv,
    GymEnv
)
from torchrl.modules import TanhNormal, AdditiveGaussianWrapper, ProbabilisticActor
from torchrl.modules.models import (
    MLP
)
from torchrl.modules.models.multiagent import (
    MultiAgentMLP,
    MultiAgentNetBase
)
from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
from torchrl.objectives import SACLoss, SoftUpdate
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.envs import EnvCreator, TransformedEnv
import multiprocessing as mp
class SMACCNet(MultiAgentNetBase):
    '''
    https://pytorch.org/rl/main/_modules/torchrl/modules/models/multiagent.html
    This is just a more limited version of MultiAgentMLP.
    '''
    def __init__(self,
                 n_agent_inputs: int | None,
                 n_agent_outputs: int,
                 n_agents: int,
                 centralised: bool,
                 share_params: bool,
                 device = None,
                 activation_class = nn.Tanh,
                 **kwargs):
        self.n_agents = n_agents
        self.n_agent_inputs = n_agent_inputs
        self.n_agent_outputs = n_agent_outputs
        self.share_params = share_params
        self.centralised = centralised
        self.activation_class = activation_class
        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device = device,
            **kwargs,
        )

    def _pre_forward_check(self, inputs):
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        # If the model is centralized, agents have full observability
        if self.centralised:
            inputs = inputs.flatten(-2, -1)
        return inputs

    def _build_single_net(self, *, device, **kwargs):
        n_agent_inputs = self.n_agent_inputs
        if self.centralised and n_agent_inputs is not None:
            n_agent_inputs = self.n_agent_inputs * self.n_agents
        # Note to self: This is where you change the model architecture.
        model = nn.Sequential(
            nn.Linear(n_agent_inputs, 400),
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Linear(300, self.n_agent_outputs)
        ).to(device)
        return model
# ripped from StackOverflow
class TqdmLoggingHandler(logging.Handler):
    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        try:
            msg = self.format(record)
            tqdm.tqdm.write(msg)
            self.flush()
        except Exception:
            self.handleError(record)
# Main Function
if __name__ == "__main__":
    logging.basicConfig(level = logging.INFO)
    logger = logging.getLogger(__name__)
    logger.propagate = False
    logger.addHandler(TqdmLoggingHandler())
    mp.set_start_method("spawn", force = True)
    NUM_CRITICS = 2
    EXPLORATION_STEPS = 100 #30000
    MAX_EPISODE_STEPS = 10 #2000
    DEVICE = "cuda"
    REPLAY_BUFFER_SIZE = 1e3 #5e5
    VALUE_GAMMA = 0.99
    BATCH_SIZE = 32 #256
    EPS = 1e-7
    LR = 3e-4
    UPDATE_STEPS_PER_BATCH = 1 #750
    WARMUP_STEPS = 10 #int(1e5)
    TRAIN_TIMESTEPS = 5000 #int(2.5e6)
    EVAL_INTERVAL = 1 #int(5e5 // EXPLORATION_STEPS)
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn(mode, parallel = True):
        base_env = PettingZooEnv(task = "multiwalker_v9",
                                 parallel = True,
                                 seed = 42,
                                 n_walkers = 3,
                                 shared_reward = False,
                                 max_cycles = 1000,
                                 render_mode = mode,
                                 device = "cpu"
                                 )
        if parallel:
            env = ParallelEnv(num_workers = 4, # noqa: E731
                              create_env_fn = EnvCreator(lambda: base_env),
                              device = "cpu",
                              mp_start_method = "spawn",
                              serial_for_single = True
                              )
        else:
            env = base_env # noqa: E731
        return lambda: env

    train_env = env_fn(None)()
    if train_env.is_closed:
        train_env.start()
    eval_env = env_fn("rgb_array", parallel = False)()
    check_env_specs(train_env)
    obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
    action_dim = train_env.full_action_spec["walker", "action"].shape[-1]
    # NOTE: When using a ParallelEnv, you must call the attributes you want to access.
    # e.g. train_env.group_map(), not train_env.group_map
    agent_grpings = train_env.group_map()[0]
    num_agents = len(agent_grpings["walker"])
    # num_agents = len(train_env.group_map["walker"])
    # Construct the raw policy network (which acts in a decentralised manner).
    # NOTE: SAC is a probabilistic policy; here we ask that the network outputs the mean AND std
    # of the action space.
    policy_net = nn.Sequential(
        SMACCNet(n_agent_inputs = obs_dim,
                 n_agent_outputs = 2 * action_dim,
                 n_agents = num_agents,
                 centralised = False, # i.e. Not a joint policy.
                 share_params = True, # But agents act from the same playbook. This becomes less of a problem with communication.
                 device = DEVICE,
                 activation_class = nn.Tanh,
                 ),
        NormalParamExtractor(),
    )
    # Construct the raw critic networks (which ARE in fact centralized).
    # As these are Q-functions, we need to add the action space too.
    # Apparently TorchRL already takes care of the multiple critic networks for you.
    critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
                          n_agent_outputs = 1,
                          n_agents = num_agents,
                          centralised = True, # i.e. A joint policy.
                          share_params = True,
                          device = DEVICE,
                          activation_class = nn.Tanh,
                          )
    # TorchRL's pipelines are VERY HEAVILY focused on performing transformations on dictionaries.
    policy_net_td_module = TensorDictModule(module = policy_net,
                                            in_keys = [("walker", "observation")],
                                            # NOTE: These outputs must match with the parameter names of the
                                            # distribution you are using!
                                            out_keys = [("walker", "loc"), ("walker", "scale")]
                                            )
    # I must confess I just copied this from the MADDPG setup.
    # However, if we are interested in using the same module code for the policy and critic networks,
    # this is probably a better solution than subclassing the MLP and saying you accept 2 inputs.
    obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
                                      in_keys = [("walker", "observation"), ("walker", "action")],
                                      out_keys = [("walker", "obs_act")]
                                      )
    critic_net_td_module = TensorDictModule(module = critic_net,
                                            in_keys = [("walker", "obs_act")],
                                            out_keys = [("walker", "state_action_value")]
                                            )
    # Attach our raw policy network to a probabilistic actor
    policy_actor = ProbabilisticActor(
        module = policy_net_td_module,
        spec = train_env.full_action_spec["walker", "action"],
        in_keys = [("walker", "loc"), ("walker", "scale")],
        out_keys = [("walker", "action")],
        distribution_class = TanhNormal,
        distribution_kwargs = {
            "min": train_env.full_action_spec["walker", "action"].space.low,
            "max": train_env.full_action_spec["walker", "action"].space.high,
        },
        return_log_prob = False,
    )
    # Hopefully there are no lazy layers, but just in case, wake them up.
    fake_td = train_env.fake_tensordict().to(DEVICE)
    policy_actor(fake_td)
    # Add some extra noise to SAC.
    # In case you ask "why add noise when the ProbabilisticActor already samples noise?":
    # doing so ensures a non-negligible *minimum* action noise for our model, which can help it explore better.
    dora = AdditiveGaussianWrapper(
        policy = policy_actor,
        action_key = ("walker", "action"),
        sigma_init = 0.3,
        sigma_end = 0.1,
        annealing_num_steps = 20000
    )
    critic_actor = TensorDictSequential(
        obs_act_module, critic_net_td_module
    )
    collector = SyncDataCollector(
        train_env,
        policy = dora, # the explora
        frames_per_batch = EXPLORATION_STEPS,
        max_frames_per_traj = MAX_EPISODE_STEPS,
        total_frames = TRAIN_TIMESTEPS,
        device = "cpu",
        reset_at_each_iter = False,
    )
    replay_buffer = TensorDictReplayBuffer(
        storage = LazyMemmapStorage(
            REPLAY_BUFFER_SIZE, device = "cpu",
        ), # We will store up to memory_size multi-agent transitions
        sampler = RandomSampler(),
        batch_size = BATCH_SIZE, # We will sample batches of this size
    )
    sac_loss = SACLoss(policy_actor,
                       qvalue_network = critic_actor,
                       num_qvalue_nets = 2,
                       loss_function = "l2",
                       delay_qvalue = True,
                       alpha_init = 0.1
                       )
    # Apparently TorchRL maintains its own default set of keys for the loss calculation.
    # To avoid key mismatches, we must change them.
    sac_loss.set_keys(
        action = ("walker", "action"),
        state_action_value = ("walker", "state_action_value"),
        reward = ("walker", "reward"),
        done = ("walker", "done"),
        terminated = ("walker", "terminated"),
    )
    # To be sure, we ARE using the Q-value-only formulation. Unfortunately, I don't know if this is necessary or not.
    # Let's leave it in for now.
    # My hunch is yes...? Because the gamma value is not yet supplied but nevertheless required for Q-learning updates.
    # https://pytorch.org/rl/stable/reference/generated/torchrl.objectives.SACLoss.html
    # TODO: FIND OUT WHAT IS GOING ON HERE
    sac_loss.make_value_estimator(gamma = VALUE_GAMMA)
    polyak_updater = SoftUpdate(sac_loss, tau = 0.005)
    critic_params = list(sac_loss.qvalue_network_params.flatten_keys().values())
    actor_params = list(sac_loss.actor_network_params.flatten_keys().values())
    optimizer_actor = torch.optim.Adam(
        actor_params,
        lr = LR,
        # weight_decay = 1e-4,
        eps = EPS,
    )
    optimizer_critic = torch.optim.Adam(
        critic_params,
        lr = LR,
        eps = EPS,
    )
    optimizer_alpha = torch.optim.Adam(
        [sac_loss.log_alpha],
        lr = 3e-4,
    )
    num_frames = 0
    pbar = tqdm.tqdm(total = TRAIN_TIMESTEPS)
    total_frames = 0
    ep_info = {'rewards': [], 'lengths': []}
    # Now everything is in place, write the training loop...
    breakpoint()
    for i, tensordict in enumerate(collector):
        collector.update_policy_weights_()
        pbar.update(tensordict.numel())
        tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
        # Add to replay buffer
        replay_buffer.extend(tensordict.cpu())
        total_frames += current_frames
        # Optimization steps
        if total_frames >= WARMUP_STEPS:
            losses = TensorDict({}, batch_size=[UPDATE_STEPS_PER_BATCH])
            for i in range(UPDATE_STEPS_PER_BATCH):
                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample()
                if sampled_tensordict.device != DEVICE:
                    sampled_tensordict = sampled_tensordict.to(
                        DEVICE, non_blocking=True
                    )
                else:
                    sampled_tensordict = sampled_tensordict.clone()
                try:
                    # Compute loss
                    loss_td = sac_loss(sampled_tensordict)
                except KeyError:
                    raise Exception(f"Check {sampled_tensordict}\n{obs_act_module(sampled_tensordict)['walker', 'obs_act']}")
                actor_loss = loss_td["loss_actor"]
                q_loss = loss_td["loss_qvalue"]
                alpha_loss = loss_td["loss_alpha"]
                # Update actor
                optimizer_actor.zero_grad()
                actor_loss.backward()
                optimizer_actor.step()
                # Update critic
                optimizer_critic.zero_grad()
                q_loss.backward()
                optimizer_critic.step()
                # Update alpha
                optimizer_alpha.zero_grad()
                alpha_loss.backward()
                optimizer_alpha.step()
                losses[i] = loss_td.select(
                    "loss_actor", "loss_qvalue", "loss_alpha"
                ).detach()
                # Update qnet_target params
                polyak_updater.step()
        if not ((i + 1) % EVAL_INTERVAL):
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
                eval_rollout = eval_env.rollout(
                    MAX_EPISODE_STEPS,
                    policy_actor,
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
                eval_reward = eval_rollout["next", "walker", "reward"].sum(-2).mean().item()
                logger.info(f"Mean Reward: {eval_reward}")
                # logger.info("SUCCESS")
        ep_reward_list = []
    collector.shutdown()
    train_env.close()
Running this code without ParallelEnv proceeds without errors. However, with ParallelEnv it fails with the following traceback:
File "/home/n00bcak/Desktop/<path_to_file>/smacc.py", line 433, in <module>
for i, tensordict in enumerate(collector):
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 888, in iterator
tensordict_out = self.rollout()
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/_utils.py", line 480, in unpack_rref_and_invoke_function
return func(self, *args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1007, in rollout
env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/common.py", line 2750, in step_and_maybe_reset
tensordict_ = self.maybe_reset(tensordict_)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/common.py", line 2795, in maybe_reset
tensordict = self.reset(tensordict)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/common.py", line 2120, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 809, in _reset
tensordict_reset = self.base_env._reset(tensordict, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/batched_envs.py", line 60, in decorated_fun
return fun(self, *args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/batched_envs.py", line 1518, in _reset
self.shared_tensordicts[i].apply_(
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/tensordict/base.py", line 4289, in apply_
return self.apply(fn, *others, inplace=True, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/tensordict/base.py", line 4387, in apply
result = self._apply_nest(
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/tensordict/_td.py", line 988, in _apply_nest
item_trsf = item._apply_nest(
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/tensordict/_td.py", line 1013, in _apply_nest
item_trsf = fn(item, *_others)
File "/home/n00bcak/Desktop/<path_to_venv>/lib/python3.10/site-packages/torchrl/envs/batched_envs.py", line 1515, in tentative_update
val.copy_(other, non_blocking=self.non_blocking)
RuntimeError: output with shape [3, 1] doesn't match the broadcast shape [3, 3]
After some time debugging I have still failed to pinpoint the source of the error, although it appears to be caused by a malformed tensordict that emerges somewhere inside TorchRL's source code.
Any insights on this issue are greatly appreciated.
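If it helps to narrow things down, the problem should be reproducible without any of the SAC machinery. Below is a stripped-down sketch (reusing the same environment settings as env_fn above; I have not confirmed it raises the exact same traceback) that only builds the ParallelEnv and runs a short random rollout, which should go through the same step_and_maybe_reset path that crashes in my full script:

# Hypothetical minimal check: does a bare ParallelEnv over PettingZooEnv already
# hit the [3, 1] vs [3, 3] broadcast error, or only once the collector is involved?
import multiprocessing as mp
from torchrl.envs import ParallelEnv, PettingZooEnv, EnvCreator, check_env_specs

def make_base_env():
    # Same settings as env_fn above (3 walkers, per-agent rewards, CPU).
    return PettingZooEnv(task = "multiwalker_v9",
                         parallel = True,
                         seed = 42,
                         n_walkers = 3,
                         shared_reward = False,
                         max_cycles = 1000,
                         device = "cpu")

if __name__ == "__main__":
    mp.set_start_method("spawn", force = True)
    env = ParallelEnv(num_workers = 4,
                      create_env_fn = EnvCreator(make_base_env),
                      device = "cpu",
                      mp_start_method = "spawn")
    check_env_specs(env)
    td = env.reset()
    print(td["walker", "observation"].shape)  # expect [num_workers, n_walkers, obs_dim]
    # break_when_any_done=False keeps stepping past episode ends, so partial resets
    # (the call that fails in the traceback above) get exercised, as far as I can tell
    rollout = env.rollout(200, break_when_any_done=False)
    print(rollout["next", "walker", "done"].shape)
    env.close()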
P.S. I have also tried using MultiSyncDataCollector as an alternative:
collector = MultiSyncDataCollector(
    [EnvCreator(lambda: train_env)] * 8, # train_env now ONLY uses PettingZoo's Parallel API
    policy = dora.to('cpu'), # the explora
    frames_per_batch = EXPLORATION_STEPS,
    max_frames_per_traj = MAX_EPISODE_STEPS,
    total_frames = TRAIN_TIMESTEPS,
    device = "cpu",
    policy_device = "cpu",
    storing_device = "cpu",
    env_device = "cpu",
    reset_at_each_iter = False
)
This results in the following error:
Traceback (most recent call last):
File "/home/n00bcak/Desktop/<path_to_script>/smacc.py", line 355, in <module>
collector = MultiSyncDataCollector(
File "/home/n00bcak/Desktop/<path_to_venv>/python3.10/site-packages/torchrl/collectors/collectors.py", line 1514, in __init__
self._run_processes()
File "/home/n00bcak/Desktop/<path_to_venv>/python3.10/site-packages/torchrl/collectors/collectors.py", line 1672, in _run_processes
proc.start()
File "/usr/lib/python3.10/multiprocessing/process.py", line 121, in start
self._popen = self._Popen(self)
File "/usr/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "/usr/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
return Popen(process_obj)
File "/usr/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
super().__init__(process_obj)
File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
self._launch(process_obj)
File "/usr/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
reduction.dump(process_obj, fp)
File "/usr/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
File "/home/n00bcak/Desktop/<path_to_venv>/python3.10/site-packages/torch/multiprocessing/reductions.py", line 568, in reduce_storage
fd, size = storage._share_fd_cpu_()
File "/home/n00bcak/Desktop/<path_to_venv>/python3.10/site-packages/torch/storage.py", line 304, in wrapper
return fn(self, *args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/python3.10/site-packages/torch/storage.py", line 374, in _share_fd_cpu_
return super()._share_fd_cpu_(*args, **kwargs)
RuntimeError: _share_fd_: only available on CPU
Any insights on this bug are appreciated as well!
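P.P.S. For the second error, one sanity check I intend to run (just a guess based on the traceback, which fails while pickling a CUDA storage to send the policy to the worker processes) is to confirm that every parameter and buffer reachable from the wrapped policy really lives on CPU before the collector is constructed:

# Hypothetical check: list any tensors in the (wrapped) policy that are still on CUDA,
# since "_share_fd_: only available on CPU" suggests something non-CPU is being pickled.
def non_cpu_tensors(module):
    offenders = []
    for name, p in module.named_parameters():
        if p.device.type != "cpu":
            offenders.append(("param: " + name, p.device))
    for name, b in module.named_buffers():
        if b.device.type != "cpu":
            offenders.append(("buffer: " + name, b.device))
    return offenders

dora = dora.to("cpu")
print(non_cpu_tensors(dora))  # hopefully an empty list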