Hi !
I am trying to write a model for Lunar Lander v3 using PPO with a double LSTM + MLP setup, one for the actor and one for the critic. I successfully implemented the actor but I am struggling with the critic. I am puzzled, as it should not be that hard — after all, the actor and critic are very similar. I ended up bypassing ValueOperator and directly using a TensorDictSequential + TensorDictModule combo with a "state_value" output key, but I am not sure this is the way to go. Here are my questions:
1/ Is this the proper way to instantiate the critic?
2/ I ended up calling make_tensordict_primer() twice. Is this the correct way to do things?
Here is my full code:
#!/usr/bin/env python3
import numpy as np
import random
import torch as t
import tyro
import warnings
warnings.filterwarnings("ignore")
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs import (Compose, DoubleToFloat, EnvCreator, InitTracker, ParallelEnv, StepCounter, TransformedEnv)
from torchrl.envs.transforms import VecNorm
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs
from torchrl.modules import OneHotCategorical, ProbabilisticActor,ValueOperator, LSTMModule, MLP
from dataclasses import dataclass
@dataclass
class Args:
    """Command-line arguments parsed by tyro."""

    seed: int = 1
    """seed of the experiment"""
class CustomModel:
    """Builds actor and critic modules (optionally recurrent) for a TorchRL env.

    Actor and critic are deliberately symmetric: an optional LSTM feature
    extractor followed by a small MLP head, wired with TensorDictSequential.
    Using TensorDictSequential + a TensorDictModule that writes "state_value"
    is a valid way to build a recurrent critic — ValueOperator is only a
    convenience wrapper around a plain nn.Module, so it is the right tool for
    the non-recurrent path and redundant for the TensorDictSequential one.

    Each LSTM needs its own make_tensordict_primer() appended to the env so
    the recurrent state is carried in the tensordict; calling it once per
    LSTMModule (i.e. twice when both actor and critic are recurrent) is
    correct, provided the two LSTMs use distinct state keys (see
    _get_critic_lstm).
    """

    def __init__(self, env):
        self.env = env
        self.device = "cpu"

    def _obs_size(self):
        # Feature size of a single observation. Index into the composite spec
        # and take the last dim: under a batched env (ParallelEnv) the
        # composite's leading shape is the batch size, not the feature size,
        # so `observation_spec.shape[0]` would be the number of workers.
        return self.env.observation_spec["observation"].shape[-1]

    def _get_mlp(self, out_features):
        """Two-layer (16, 16) MLP head with `out_features` outputs."""
        return MLP(num_cells=[16, 16], out_features=out_features, device=self.device)

    def _get_actor_lstm(self, actor_mlp):
        """LSTM feature extractor followed by the logits head for the actor."""
        self.actor_lstm = LSTMModule(
            input_size=self._obs_size(),
            hidden_size=32,
            device=self.device,  # was the global `device`; use the instance's
            in_keys=["observation", "rs_h", "rs_c"],
            out_keys=["actor_features", ("next", "rs_h"), ("next", "rs_c")],
        )
        actor_head = TensorDictModule(
            actor_mlp, in_keys=["actor_features"], out_keys=["logits"]
        )
        return TensorDictSequential(self.actor_lstm, actor_head)

    def get_actor(self, use_lstm):
        """Return a ProbabilisticActor; recurrent when `use_lstm` is True.

        When recurrent, the LSTM's tensordict primer is appended to the env
        so reset() seeds the hidden state.
        """
        actor_mlp = self._get_mlp(out_features=self.env.action_spec.space.n)
        if use_lstm:
            policy_module = self._get_actor_lstm(actor_mlp)
            self.env.append_transform(self.actor_lstm.make_tensordict_primer())
        else:
            policy_module = TensorDictModule(
                actor_mlp, in_keys=["observation"], out_keys=["logits"]
            )
        return ProbabilisticActor(
            module=policy_module,
            spec=self.env.action_spec,  # was the global `parallel_env`; use self.env
            in_keys=["logits"],
            distribution_class=OneHotCategorical,
            return_log_prob=True,
        )

    def _get_critic_lstm(self, critic_mlp):
        """LSTM feature extractor followed by the value head for the critic.

        The critic's recurrent state lives under its own keys
        ("rs_h_critic"/"rs_c_critic"): if it shared "rs_h"/"rs_c" with the
        actor, the two primers would collide and, with both LSTMs active,
        actor and critic would overwrite each other's hidden state.
        """
        self.critic_lstm = LSTMModule(
            input_size=self._obs_size(),
            hidden_size=32,
            device=self.device,  # was the global `device`; use the instance's
            in_keys=["observation", "rs_h_critic", "rs_c_critic"],
            out_keys=[
                "critic_features",
                ("next", "rs_h_critic"),
                ("next", "rs_c_critic"),
            ],
        )
        critic_head = TensorDictModule(
            critic_mlp, in_keys=["critic_features"], out_keys=["state_value"]
        )
        return TensorDictSequential(self.critic_lstm, critic_head)

    def get_critic(self, use_lstm):
        """Return a critic writing "state_value"; recurrent when `use_lstm`."""
        critic_mlp = self._get_mlp(out_features=1)
        if use_lstm:
            critic_module = self._get_critic_lstm(critic_mlp)
            self.env.append_transform(self.critic_lstm.make_tensordict_primer())
        else:
            # ValueOperator takes a plain nn.Module; wrapping the MLP in a
            # TensorDictModule first (as the original did) double-wraps it
            # and breaks the in/out key routing.
            critic_module = ValueOperator(module=critic_mlp, in_keys=["observation"])
        return critic_module
if __name__ == "__main__":
    args = tyro.cli(Args)
    device = "cpu"

    # Seed every RNG in play so runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    t.manual_seed(args.seed)

    # 8 LunarLander workers with obs/reward normalization, dtype conversion,
    # step counting, and the InitTracker that LSTM primers rely on.
    parallel_env = TransformedEnv(
        ParallelEnv(8, EnvCreator(lambda: GymEnv("LunarLander-v3", device=device))),
        Compose(
            VecNorm(in_keys=["observation", "reward"], eps=1e-8, new_api=True),
            DoubleToFloat(),
            StepCounter(),
            InitTracker(),
        ),
    )
    check_env_specs(parallel_env)

    model = CustomModel(parallel_env)

    # Smoke-test every actor/critic recurrence combination.
    # NOTE(review): each build with use_lstm=True appends another primer
    # transform to the env, so transforms accumulate across iterations —
    # confirm this is intended before reusing the env for training.
    for label, actor_lstm, critic_lstm in (
        ("Mlp only", False, False),
        ("Actor Lstm", True, False),
        ("Critic Lstm", False, True),
        ("Both Lstm", True, True),
    ):
        print(label)
        policy_module = model.get_actor(use_lstm=actor_lstm)
        critic_module = model.get_critic(use_lstm=critic_lstm)
        policy_module(parallel_env.reset())
        critic_module(parallel_env.reset())
Thanks in advance!