How to instantiate double LSTM + MLP for actor/critic in PPO

Hi !

I am trying to write a model for LunarLander-v3 with PPO using a double LSTM + MLP, one for the actor and one for the critic. I successfully implemented the actor but am struggling with the critic. I am puzzled, as it should not be that hard — after all, actor and critic are very similar. I ended up bypassing ValueOperator to directly use a TensorDictSequential + TensorDictModule combo with a “state_value” output key, but I am not sure this is the way to go. Here are my questions:

1/ Is this the proper way to instantiate the critic?

2/ I ended up calling make_tensordict_primer() twice. Is this the correct way to do things?

Here is my full code:

#!/usr/bin/env python3

import numpy as np
import random
import torch as t
import tyro
import warnings

warnings.filterwarnings("ignore")

from tensordict.nn import TensorDictModule, TensorDictSequential

from torchrl.envs import (Compose, DoubleToFloat, EnvCreator, InitTracker, ParallelEnv, StepCounter, TransformedEnv)
from torchrl.envs.transforms import VecNorm
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs
from torchrl.modules import OneHotCategorical, ProbabilisticActor,ValueOperator, LSTMModule, MLP
from dataclasses import dataclass

@dataclass
class Args:
    # Command-line arguments parsed by tyro.cli(); the bare string below each
    # field is picked up by tyro as the flag's help text.
    seed: int = 1
    """seed of the experiment"""

class CustomModel:
    """Builds actor and critic modules (optionally LSTM-based) for a discrete-action env.

    When an LSTM is requested, the corresponding ``make_tensordict_primer()``
    transform is appended to the env so resets pre-populate the recurrent
    state keys. Actor and critic each get their OWN state keys (rs_* vs cs_*)
    so the two primers never collide.
    """

    def __init__(self, env):
        self.env = env
        self.device = "cpu"

    def _get_mlp(self, out_features):
        """Return a small 2x16-cell MLP with the requested output width."""
        return MLP(num_cells=[16, 16], out_features=out_features, device=self.device)

    def _get_actor_lstm(self, actor_mlp):
        """Chain an LSTM in front of the actor MLP; recurrent state in rs_h/rs_c."""
        self.actor_lstm = LSTMModule(
            # Feature size of a single observation (last dim of the leaf spec),
            # NOT observation_spec.shape[0], which is the env batch dimension.
            input_size=self.env.observation_spec["observation"].shape[-1],
            hidden_size=32,
            device=self.device,  # was the module-level global `device`
            in_keys=["observation", "rs_h", "rs_c"],
            out_keys=["actor_features", ("next", "rs_h"), ("next", "rs_c")],
        )
        actor_mlp = TensorDictModule(actor_mlp, in_keys=["actor_features"], out_keys=["logits"])
        return TensorDictSequential(self.actor_lstm, actor_mlp)

    def get_actor(self, use_lstm):
        """Return a ProbabilisticActor reading "logits" and emitting one-hot actions.

        If ``use_lstm`` is True, the actor LSTM's tensordict primer is appended
        to the env so rs_h/rs_c exist in reset tensordicts.
        """
        actor_mlp = self._get_mlp(out_features=self.env.action_spec.space.n)
        if use_lstm:
            policy_module = self._get_actor_lstm(actor_mlp)
            self.env.append_transform(self.actor_lstm.make_tensordict_primer())
        else:
            policy_module = TensorDictModule(actor_mlp, in_keys=["observation"], out_keys=["logits"])
        return ProbabilisticActor(
            module=policy_module,
            spec=self.env.action_spec,  # was the global `parallel_env` — use the instance's env
            in_keys=["logits"],
            distribution_class=OneHotCategorical,
            return_log_prob=True
        )

    def _get_critic_lstm(self, critic_mlp):
        """Chain an LSTM in front of the critic MLP.

        Uses dedicated recurrent-state keys (cs_h/cs_c) so they do not clash
        with the actor's rs_h/rs_c when both LSTMs are active at once.
        """
        self.critic_lstm = LSTMModule(
            input_size=self.env.observation_spec["observation"].shape[-1],
            hidden_size=32,
            device=self.device,
            in_keys=["observation", "cs_h", "cs_c"],
            out_keys=["critic_features", ("next", "cs_h"), ("next", "cs_c")],
        )
        critic_mlp = TensorDictModule(critic_mlp, in_keys=["critic_features"], out_keys=["state_value"])
        return TensorDictSequential(self.critic_lstm, critic_mlp)

    def get_critic(self, use_lstm):
        """Return a value module that writes its estimate under "state_value".

        If ``use_lstm`` is True, the critic LSTM's own tensordict primer is
        appended to the env (one primer per LSTM is the intended usage).
        """
        critic_mlp = self._get_mlp(out_features=1)
        if use_lstm:
            critic_module = self._get_critic_lstm(critic_mlp)
            self.env.append_transform(self.critic_lstm.make_tensordict_primer())
        else:
            # ValueOperator expects a plain nn.Module — it performs the
            # TensorDictModule wrapping itself and writes to "state_value"
            # by default; passing a pre-wrapped TensorDictModule is wrong.
            critic_module = ValueOperator(module=critic_mlp, in_keys=["observation"])
        return critic_module

if __name__ == "__main__":
    # Parse CLI flags and seed every RNG source for reproducibility.
    args = tyro.cli(Args)
    device = "cpu"
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    t.manual_seed(seed)

    # Eight LunarLander workers behind normalization / bookkeeping transforms.
    parallel_env = TransformedEnv(
        ParallelEnv(8, EnvCreator(lambda: GymEnv("LunarLander-v3", device=device))),
        Compose(
            VecNorm(in_keys=["observation", "reward"], eps=1e-8, new_api=True),
            DoubleToFloat(),
            StepCounter(),
            InitTracker(),
        ),
    )
    check_env_specs(parallel_env)

    model = CustomModel(parallel_env)

    # Smoke-test every actor/critic LSTM combination on a freshly reset env.
    combos = (
        ('Mlp only', False, False),
        ('Actor Lstm', True, False),
        ('Critic Lstm', False, True),
        ('Both Lstm', True, True),
    )
    for label, actor_uses_lstm, critic_uses_lstm in combos:
        print(label)
        policy_module = model.get_actor(use_lstm=actor_uses_lstm)
        critic_module = model.get_critic(use_lstm=critic_uses_lstm)
        policy_module(parallel_env.reset())
        critic_module(parallel_env.reset())

Thanks in advance!