PPO with Categorical Action... help

Hello, I’m looking for examples or tutorials on PPO using torchrl with Categorical Actions for a custom Env. I’ve tried to adapt the inverted-double-pendulum tutorial but without success. If anyone can help me I would be very grateful…

Could you describe what exactly you’ve tried so far and where you were stuck?

This could help others to chime in suggesting potential next steps etc.

Ok, first here is a simple Custom Env called “CosTrader” the idea is a cosinus function is running in time, an agent “the Trader” has to learn how to buy at down and close at up. The agent has a categorical action with shape = 1 and n=3 (he take only one action at time and can 0= close position, 1= do nothing, 2= open a buy position). Of course it is not for financial purpose but for a learning purpose! Later I want to introduce noise, varying level(k in params), etc… Nb:I have to make a mask on action but not implemented.
So here is the code “CosTrader.py”:

import torch
from tensordict.tensordict import TensorDict, TensorDictBase
from torchrl.data import Bounded, Composite, Unbounded, Categorical
from torchrl.envs import EnvBase
from torchrl.envs.utils import check_env_specs

from typing import Optional
import random

#############
rendering = True
#############
if rendering:
    import pygame
    from pygame import gfxdraw

def gen_params(batch_size=None) -> TensorDictBase:
    if batch_size is None:
        batch_size = []
    # f(x) = a.cos(b(x-h))+k
    td = TensorDict({"params": TensorDict({"a": 0.2, # ampli
                                           "b": 1, # freq
                                           "h": random.uniform(0,1), # phase
                                           "k": 1.25, # niveau
                                           "dt": 0.05},
                                          [],) },
                    [],)

    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td

def make_composite_from_td(td):
    # custom funtion to convert a tensordict in a similar spec structure
    # of unbounded values.
    composite = Composite({key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase)
                               else Unbounded( dtype=tensor.dtype,
                                               device=tensor.device,
                                               shape=tensor.shape )
                               for key, tensor in td.items()},
                               shape=td.shape )
    return composite


class CosTraderEnv(EnvBase):
    batch_locked = False
    AUTO_UNWRAP_TRANSFORMED_ENV = False
    
    
    def __init__(self, td_params=None, seed=None, device="cpu", batch_size = []):
        super().__init__(device=device, batch_size=batch_size)
        
        if td_params is None:
            td_params = self.gen_params(batch_size)
        
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

        self.screen = None
        self.clock = None

    def _step(self,tensordict):
        a = tensordict["params", "a"]
        b = tensordict["params", "b"]
        h = tensordict["params", "h"]
        k = tensordict["params", "k"]
        dt = tensordict["params", "dt"]

        bid   = tensordict["bid"]      # cours
        angle = tensordict["angle"]    # tg à la courbe
        posA  = tensordict["posA"]     # position Achat
        solde = tensordict["solde"]    # solde
        t     = tensordict["t"]        # temps
        
        u = tensordict["action"].squeeze(-1)
        #action:  0 = close,   1 = rien,    2 = Achat

        C = torch.where(u == 0, 1, 0)
        A = torch.where(u == 2, 1, 0)
        R = torch.where(u == 1, 1, 0)
        
        no_pos = torch.where(posA == 0, 1, 0)
        in_pos = torch.where(posA >  0, 1, 0)
        add = torch.where(u == 2, bid*no_pos, 0)        
        new_posA = posA + add
        new_posA = torch.where(u == 0, 0, new_posA)
        
        new_t = t + dt        
        new_bid = a*(b*(t-h)).cos()+k  # f(x) = a.cos(b(x-h))+k
        new_angle = (new_bid - bid)/dt
        new_solde = new_bid - posA

        ####
        #reward = torch.zeros_like(u, dtype=torch.float32)
        reward_C = C* in_pos *(bid - posA)
        reward_R = (R* in_pos - R* no_pos) *(new_bid-bid)
        reward_A = A* no_pos *(new_bid-bid)
        reward = reward_C + reward_R + reward_A
        
        ####
        
        done = torch.zeros_like(reward, dtype=torch.bool)
        
        nextTD = TensorDict({"params": tensordict["params"],
                             "angle": new_angle,
                             "bid": new_bid,
                             "posA": new_posA,
                             "t": new_t,
                             "solde":new_solde,
                             "reward": reward,
                             "done": done,},
                            tensordict.shape,)

        if rendering:
            self.state = (angle.tolist(),bid.tolist(),new_posA.tolist())
            self.last_u = 0
            self.render()
        
        return nextTD


    def _reset(self, tensordict):
        if tensordict is None or tensordict.is_empty():
            tensordict = self.gen_params(batch_size=self.batch_size)
        t = torch.zeros(tensordict.shape, device = self.device)
        posA = torch.zeros(tensordict.shape, device = self.device)
        solde= torch.zeros(tensordict.shape, device = self.device)
        a = tensordict["params", "a"]
        b = tensordict["params", "b"]
        h = tensordict["params", "h"]
        k = tensordict["params", "k"]
        dt = tensordict["params", "dt"]
        
        bid =  a* (b*(t-h)).cos() + k #a.cos(b(x-h))+k
        angle = torch.zeros(tensordict.shape, device = self.device) #...

        out = TensorDict({"params": tensordict["params"],
                          "angle": angle,
                          "bid": bid,
                          "posA": posA,
                          "t": t,
                          "solde":solde},
                         batch_size=tensordict.shape,)

        if rendering:
            self.last_u = None 
            self.state = (angle.tolist(),bid.tolist(),posA.tolist())
            self.render()

        return out

        
    def _make_spec(self,td_params):
        # Under the hood, this will populate self.output_spec["observation"]
        self.observation_spec = Composite(angle=Bounded(low = -torch.pi/2,
                                                        high = torch.pi/2,
                                                        shape = (),
                                                        dtype = torch.float32),
                                          bid=Unbounded(shape = (),
                                                        dtype = torch.float32),
                                          posA=Unbounded(shape = (),
                                                         dtype = torch.float32),
                                          t=Bounded(low = 0,
                                                    high = 1000,
                                                    shape = (),
                                                    dtype = torch.float32),
                                          solde=Unbounded(shape = (),
                                                          dtype = torch.float32),
                                          params=make_composite_from_td(td_params["params"]),
                                          shape= ())
        
        self.state_spec = self.observation_spec.clone()
        
        self.action_spec = Categorical(n = 3, shape = (*td_params.shape,1))
        
        self.reward_spec = Unbounded(shape = (*td_params.shape, 1))
        
    gen_params = staticmethod(gen_params)
    
    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

    def get_obskeys(self):
        return ["angle", "bid","posA","t","solde"]

    def render(self):
        
        if self.screen is None:
            self.screen_dim = 400
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((self.screen_dim,self.screen_dim))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                raise SystemExit

        self.surf = pygame.Surface((self.screen_dim, self.screen_dim))
        self.surf.fill((255, 255, 255))

        bound = 2.2
        scale = self.screen_dim / (bound * 2)
        offset = self.screen_dim // 2

        rod_length = 1 * scale
        rod_width = 0.02 * scale
        l, r, t, b = 0, rod_length, rod_width / 2, -rod_width / 2
        coords = [(l, b), (l, t), (r, t), (r, b)]
        transformed_coords = []
        for c in coords:
            try:
                c = pygame.math.Vector2(c).rotate_rad(self.state[0])# + np.pi / 2)
            except:
                c = pygame.math.Vector2(c).rotate_rad(self.state[0][0])# + np.pi / 2)
            c = (c[0] + offset, c[1] + offset)
            transformed_coords.append(c)
        gfxdraw.aapolygon(self.surf, transformed_coords, (204, 77, 77))
        gfxdraw.filled_polygon(self.surf, transformed_coords, (204, 77, 77))

        gfxdraw.aacircle(self.surf, offset, offset, int(rod_width / 2), (204, 77, 77))
        gfxdraw.filled_circle(
            self.surf, offset, offset, int(rod_width / 2), (204, 77, 77)
        )

        # drawing bid
        
        try:
            bid = self.state[1]
            gfxdraw.filled_circle(self.surf, 5,int(9*(bid-1)* scale), int(0.05 * scale), (0, 0, 0))
        except:
            bid = self.state[1][0]
            #gfxdraw.aacircle(self.surf, 5, int(9*(bid-1)* scale), int(0.05 * scale), (0, 0, 0))
            gfxdraw.filled_circle(self.surf, 5,int(9*(bid-1)* scale), int(0.05 * scale), (0, 0, 0))



         # drawing posA        
        try:
            vac = self.state[2]
            gfxdraw.filled_circle(self.surf, self.screen_dim-5,int(9*(vac-1)* scale), int(0.05 * scale), (0, 255, 0))
        except:
            vac = self.state[2][0]
            #gfxdraw.aacircle(self.surf, 5, int(9*(bid-1)* scale), int(0.05 * scale), (0, 0, 0))
            gfxdraw.filled_circle(self.surf, self.screen_dim-5,int(9*(vac-1)* scale), int(0.05 * scale), (0, 255, 0))

        #print(bid,vac)
        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
       
        
        self.clock.tick(30)
        pygame.display.flip()

    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()



if __name__ == "__main__":
    ### Tests ###
    env = CosTraderEnv()#batch_size = torch.Size([8]))
    check_env_specs(env)
    print("\n* observation_spec:", env.observation_spec)
    print("\n* action_spec:", env.action_spec)
    print("\n* reward_spec:", env.reward_spec)
    
    print("\n* random action 5: \n", env.action_spec.rand(torch.Size([5])))
    
    td = env.reset()
    print("\n* reset tensordict", td)

    td = env.rand_step(td)
    print("\n* random step tensordict", td)
    env.close()      

Secondly I want to train my agent with a PPO (and honestly I still have troubles understanding all the workings of it), so I tried to adapt the “DoubleInvertedPendulum” witch is a continuous action.
Here is the code “CosPPO.py” in progress (but not quickly…)

from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torch.distributions import Normal

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
    UnsqueezeTransform,
    CatTensors,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
#from tqdm import tqdm


#is_fork = multiprocessing.get_start_method() == "fork"
device = (torch.device("cpu"))
"""          
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu"))
"""    
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0


frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 10_000


sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4


import CosTrader
env = CosTrader.CosTraderEnv()#batch_size = torch.Size([8]))
# Transform ******************
obs_keys = env.get_obskeys()
env = TransformedEnv(
    env,
    # Unsqueezes the observations that we will concatenate    
    UnsqueezeTransform(
        dim=-1,
        in_keys=obs_keys,
        in_keys_inv=obs_keys,
    ),
)
cat_transform = CatTensors(
    in_keys=obs_keys, dim=-1, out_key="observation", del_keys=False
)
env.append_transform(cat_transform)
env = TransformedEnv(
    env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)

# ******************************


env.transform[-3].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)


print("normalization constant shape:", env.transform[-3].loc.shape)

print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)

check_env_specs(env)

rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)

######################################################################
# Batching computations

batch_size = 10   
td = env.reset(env.gen_params(batch_size=[batch_size]))
print("\n* reset (batch size of 10)", td)

td = env.rand_step(td)
print("\n* rand step (batch size of 10)", td)


# Hyperparamètres ---------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_cells = 64  # number of cells in each layer i.e. output dim.
lr = 2e-3
max_grad_norm = 1.0
# -------------------------------------------------------------

policy_net = nn.Sequential(nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
                          NormalParamExtractor())

policy_module = TensorDictModule(policy_net,
                                 in_keys = ["observation"],
                                 out_keys = ["loc", "scale"])

policy_module = ProbabilisticActor(module=policy_module,
                                   spec = env.action_spec,
                                   in_keys = ["loc", "scale"],
                                   out_keys= ["action"],
                                   distribution_class = Normal,#TanhNormal,
                                   #distribution_kwargs = {"min": -1, "max": 1},
                                   return_log_prob = True)

critic_net = nn.Sequential(nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(1, device=device))

critic = ValueOperator(module = critic_net,
                             in_keys = ["observation"])

print("\n* Running policy:", policy_module(env.reset()))
print("\n* Running criticvalue:", critic(env.reset()))


# Data paramètres ---------------------------------------------
frame_skip = 1  
frames_per_batch = 150 // frame_skip
total_frames = 40_050 // frame_skip
minibatch_size = 64
# -------------------------------------------------------------

collector = SyncDataCollector(env,
                              policy_module,
                              frames_per_batch=frames_per_batch,
                              total_frames=total_frames,
                              split_trajs=False,
                              device=device,)


replay_buffer = ReplayBuffer(storage = LazyTensorStorage(frames_per_batch),
                             sampler = SamplerWithoutReplacement(),
                             batch_size = minibatch_size)

# PPO paramètres ---------------------------------------------
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = (0.2)  # clip value for PPO loss: see the equation in the intro for more context.
gamma = 0.99 # [0 ,1]
lmbda = 0.95 # [0 ,1]
entropy_eps = 1e-4
# -------------------------------------------------------------

advantage_module = GAE(gamma = gamma,
                       lmbda = lmbda,
                       value_network = critic,
                       average_gae = True)

loss_module = ClipPPOLoss(actor = policy_module,
                          critic = critic,
                          advantage_key = "advantage",
                          clip_epsilon = clip_epsilon,
                          entropy_bonus = bool(entropy_eps),
                          entropy_coeff = entropy_eps,
                          # these keys match by default but we set this for completeness
                          value_target_key = advantage_module.value_target_key,
                          critic_coeff = 1.0,
                          gamma = gamma,
                          loss_critic_type = "smooth_l1",)

optim = torch.optim.Adam(loss_module.parameters(), lr)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_frames // frames_per_batch, 0.0)

logs = defaultdict(list)

total=total_frames * frame_skip
faits = 0
eval_str = ""

print("TRAINING LOOP ***********************************************************")
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    #print(tensordict_data["action"])
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        
        data_view = tensordict_data.reshape(-1)
        
        replay_buffer.extend(data_view.cpu())
        
        for _ in range(frames_per_batch // minibatch_size):
            subdata = replay_buffer.sample(minibatch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (loss_vals["loss_objective"] + loss_vals["loss_critic"] + loss_vals["loss_entropy"])

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            
            # this is not strictly mandatory but it's good practice to keep gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    faits += tensordict_data.numel() * frame_skip
    cum_reward_str = (f"average reward={logs['reward'][-1]: 4.4f}")# (init={logs['reward'][0]: 4.4f})")
    #logs["step_count"].append(tensordict_data["step_count"].max().item())
    #stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(eval_rollout["next", "reward"].sum().item())
            #logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (f"cum reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                        #f"step-count: {logs['eval step_count'][-1]}"
                        )
            
            del eval_rollout
    collector.update_policy_weights_() ## ??
    progres = 100*faits/total
    print(f"{progres: 3.1f}%",", ".join([eval_str, cum_reward_str, lr_str]))
    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()

when running code here are the tests:

(IAvenv) fauche@debian:~/Documents/RenforcementLearning$ python CosPPO.py
pygame 2.6.1 (SDL 2.28.4, Python 3.11.2)
Hello from the pygame community. https://www.pygame.org/contribute.html
/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py:831: FutureWarning: The default behavior of TransformedEnv will change in version 0.9. Nested TransformedEnvs will no longer be automatically unwrapped by default. To prepare for this change, use set_auto_unwrap_transformed_env(val: bool) as a decorator or context manager, or set the environment variable AUTO_UNWRAP_TRANSFORMED_ENV to 'False'.
  instance: EnvBase = super(_EnvPostInit, self).__call__(*args, **kwargs)
normalization constant shape: torch.Size([5])
observation_spec: Composite(
    angle: BoundedContinuous(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    bid: UnboundedContinuous(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    posA: UnboundedContinuous(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    t: BoundedContinuous(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    solde: UnboundedContinuous(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    params: Composite(
        a: UnboundedContinuous(
            shape=torch.Size([]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        b: UnboundedDiscrete(
            shape=torch.Size([]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True),
                high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True)),
            device=cpu,
            dtype=torch.int64,
            domain=discrete),
        h: UnboundedContinuous(
            shape=torch.Size([]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        k: UnboundedContinuous(
            shape=torch.Size([]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        dt: UnboundedContinuous(
            shape=torch.Size([]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        device=cpu,
        shape=torch.Size([]),
        data_cls=None),
    observation: UnboundedContinuous(
        shape=torch.Size([5]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    step_count: BoundedDiscrete(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
        device=cpu,
        dtype=torch.int64,
        domain=discrete),
    device=cpu,
    shape=torch.Size([]),
    data_cls=None)
reward_spec: UnboundedContinuous(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)
input_spec: Composite(
    full_state_spec: Composite(
        angle: BoundedContinuous(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        bid: UnboundedContinuous(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        posA: UnboundedContinuous(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        t: BoundedContinuous(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        solde: UnboundedContinuous(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        params: Composite(
            a: UnboundedContinuous(
                shape=torch.Size([]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous),
            b: UnboundedDiscrete(
                shape=torch.Size([]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True),
                    high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True)),
                device=cpu,
                dtype=torch.int64,
                domain=discrete),
            h: UnboundedContinuous(
                shape=torch.Size([]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous),
            k: UnboundedContinuous(
                shape=torch.Size([]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous),
            dt: UnboundedContinuous(
                shape=torch.Size([]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous),
            device=cpu,
            shape=torch.Size([]),
            data_cls=None),
        step_count: BoundedDiscrete(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
            device=cpu,
            dtype=torch.int64,
            domain=discrete),
        device=cpu,
        shape=torch.Size([]),
        data_cls=None),
    full_action_spec: Composite(
        action: Categorical(
            shape=torch.Size([1]),
            space=CategoricalBox(n=3),
            device=cpu,
            dtype=torch.int64,
            domain=discrete),
        device=cpu,
        shape=torch.Size([]),
        data_cls=None),
    device=cpu,
    shape=torch.Size([]),
    data_cls=None)
action_spec (as defined by input_spec): Categorical(
    shape=torch.Size([1]),
    space=CategoricalBox(n=3),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)
2025-08-09 10:51:22,401 [torchrl][INFO]    check_env_specs succeeded! [END]
rollout of three steps: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        angle: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        bid: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                angle: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                bid: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
                params: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int64, is_shared=False),
                        dt: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                        h: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                        k: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([3]),
                    device=None,
                    is_shared=False),
                posA: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                solde: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                t: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int64, is_shared=False),
                dt: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                h: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                k: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        posA: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        solde: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        t: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
Shape of the rollout TensorDict: torch.Size([3])

* reset (batch size of 10) TensorDict(
    fields={
        angle: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        bid: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([10, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                dt: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                h: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                k: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        posA: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        solde: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        t: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

* rand step (batch size of 10) TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        angle: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        bid: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                angle: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                bid: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 5]), device=cpu, dtype=torch.float32, is_shared=False),
                params: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                        dt: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        h: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        k: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([10]),
                    device=None,
                    is_shared=False),
                posA: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                solde: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                t: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                dt: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                h: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                k: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        posA: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        solde: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        t: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

* Running policy: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        action_log_prob: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        angle: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        bid: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                dt: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                h: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                k: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        posA: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        solde: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        t: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

* Running criticvalue: TensorDict(
    fields={
        angle: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        bid: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                dt: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                h: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                k: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        posA: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        solde: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        t: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

And then I have an error(and certainly many others) in loss_module with this traceback:

Traceback (most recent call last):
  File "/home/fauche/Documents/RenforcementLearning/CosPPO.py", line 190, in <module>
    loss_module = ClipPPOLoss(actor = policy_module,
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/objectives/ppo.py", line 1094, in __init__
    super().__init__(
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/objectives/ppo.py", line 490, in __init__
    raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
TypeError: Passing gamma / lambda parameters through the loss constructor is a deprecated feature. To customize your value function, run `loss_module.make_value_estimator(ValueEstimators.<value_fun>, gamma=val)`.


Thank you very much for your help !

OK I understand… I take an old tuto as a reference ! I modify “CosPPO.py” and I come back.

Modified code for “CosPPO.py”:

from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torch.distributions import Normal

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
    UnsqueezeTransform,
    CatTensors,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
#from tqdm import tqdm


#is_fork = multiprocessing.get_start_method() == "fork"
device = (torch.device("cpu"))
"""          
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu"))
"""    
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0


frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 10_000


sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4


import CosTrader
env = CosTrader.CosTraderEnv()#batch_size = torch.Size([8]))
# Transform ******************
obs_keys = env.get_obskeys()
env = TransformedEnv(
    env,
    # Unsqueezes the observations that we will concatenate    
    UnsqueezeTransform(
        dim=-1,
        in_keys=obs_keys,
        in_keys_inv=obs_keys,
    ),
)
cat_transform = CatTensors(
    in_keys=obs_keys, dim=-1, out_key="observation", del_keys=False
)
env.append_transform(cat_transform)
env = TransformedEnv(
    env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)

# ******************************


env.transform[-3].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)


print("normalization constant shape:", env.transform[-3].loc.shape)

print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)

check_env_specs(env)

rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)

######################################################################
# Batching computations

batch_size = 10   
td = env.reset(env.gen_params(batch_size=[batch_size]))
print("\n* reset (batch size of 10)", td)

td = env.rand_step(td)
print("\n* rand step (batch size of 10)", td)


# Hyperparamètres ---------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_cells = 64  # number of cells in each layer i.e. output dim.
lr = 2e-3
max_grad_norm = 1.0
# -------------------------------------------------------------

policy_net = nn.Sequential(nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
                          NormalParamExtractor())

policy_module = TensorDictModule(policy_net,
                                 in_keys = ["observation"],
                                 out_keys = ["loc", "scale"])

policy_module = ProbabilisticActor(module=policy_module,
                                   spec = env.action_spec,
                                   in_keys = ["loc", "scale"],
                                   out_keys= ["action"],
                                   distribution_class = Normal,#TanhNormal,
                                   #distribution_kwargs = {"min": -1, "max": 1},
                                   return_log_prob = True)

critic_net = nn.Sequential(nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(num_cells, device=device),
                          nn.Tanh(),
                          nn.LazyLinear(1, device=device))

critic = ValueOperator(module = critic_net,
                             in_keys = ["observation"])

print("\n* Running policy:", policy_module(env.reset()))
print("\n* Running criticvalue:", critic(env.reset()))


# Data paramètres ---------------------------------------------
frame_skip = 1  
frames_per_batch = 150 // frame_skip
total_frames = 40_050 // frame_skip
minibatch_size = 64
# -------------------------------------------------------------

collector = SyncDataCollector(env,
                              policy_module,
                              frames_per_batch=frames_per_batch,
                              total_frames=total_frames,
                              split_trajs=False,
                              device=device,)


replay_buffer = ReplayBuffer(storage = LazyTensorStorage(frames_per_batch),
                             sampler = SamplerWithoutReplacement(),
                             batch_size = minibatch_size)

# PPO paramètres ---------------------------------------------
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = (0.2)  # clip value for PPO loss: see the equation in the intro for more context.
gamma = 0.99 # [0 ,1]
lmbda = 0.95 # [0 ,1]
entropy_eps = 1e-4
# -------------------------------------------------------------

advantage_module = GAE(gamma = gamma,
                       lmbda = lmbda,
                       value_network = critic,
                       average_gae = True)

loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network= critic,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_frames // frames_per_batch, 0.0)

logs = defaultdict(list)

total=total_frames * frame_skip
faits = 0
eval_str = ""

print("TRAINING LOOP ***********************************************************")
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    #print(tensordict_data["action"])
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        
        data_view = tensordict_data.reshape(-1)
        
        replay_buffer.extend(data_view.cpu())
        
        for _ in range(frames_per_batch // minibatch_size):
            subdata = replay_buffer.sample(minibatch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (loss_vals["loss_objective"] + loss_vals["loss_critic"] + loss_vals["loss_entropy"])

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            
            # this is not strictly mandatory but it's good practice to keep gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    faits += tensordict_data.numel() * frame_skip
    cum_reward_str = (f"average reward={logs['reward'][-1]: 4.4f}")# (init={logs['reward'][0]: 4.4f})")
    #logs["step_count"].append(tensordict_data["step_count"].max().item())
    #stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(eval_rollout["next", "reward"].sum().item())
            #logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (f"cum reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                        #f"step-count: {logs['eval step_count'][-1]}"
                        )
            
            del eval_rollout
    collector.update_policy_weights_() ## ??
    progres = 100*faits/total
    print(f"{progres: 3.1f}%",", ".join([eval_str, cum_reward_str, lr_str]))
    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()

plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
#plt.plot(logs["step_count"])
#plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
#plt.plot(logs["eval step_count"])
#plt.title("Max step count (test)")
plt.show() 

But It don’t learn, and when looking at the actions I see float, not int ! So I tried safe = True in ProbalisticActor, actions are float but correspond (0.0,1.0,2.0) is it the right way ?

Ok, I’ve changed distribution_class to MaskedCategorical (because I add a dynamic mask, other why a Categorical distribution should be used for the upside version). I’ve changed in_keys too with logits and action_mask in ProbabilisticActor, and the out_keys from net to “logits”. I removed NormalParamExtractor too. And other changes…

Hello
For categorical actions you should definitely use a Categorical distribution.
Here are some pointers:

LMK if you have specific questions about these!

Hello, when taking your action mask example, your Env don’t pass “check_env_specs”:

import torch
from torchrl.data.tensor_specs import Categorical, Binary, Unbounded, Composite
from torchrl.envs.transforms import ActionMask, TransformedEnv
from torchrl.envs.common import EnvBase
from torchrl.envs.utils import check_env_specs

class MaskedEnv(EnvBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.action_spec = Categorical(4)
        self.state_spec = Composite(action_mask=Binary(4, dtype=torch.bool))
        self.observation_spec = Composite(obs=Unbounded(3))
        self.reward_spec = Unbounded(1)
        
    def _reset(self, tensordict=None):
        td = self.observation_spec.rand()
        td.update(torch.ones_like(self.state_spec.rand()))
        return td
    
    def _step(self, data):
        td = self.observation_spec.rand()
        mask = data.get("action_mask")
        action = data.get("action")
        mask = mask.scatter(-1, action.unsqueeze(-1), 0)
        td.set("action_mask", mask)
        td.set("reward", self.reward_spec.rand())
        td.set("done", ~mask.any().view(1))
        return td

    def _set_seed(self, seed) -> None:
        pass
torch.manual_seed(0)
base_env = MaskedEnv()
env = TransformedEnv(base_env, ActionMask())
check_env_specs(env)

it return this traceback:

Traceback (most recent call last):
  File "/home/fauche/Documents/RenforcementLearning/testActionMask.py", line 34, in <module>
    check_env_specs(env)
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/envs/utils.py", line 789, in check_env_specs
    raise AssertionError(
AssertionError: The keys of the specs and data do not match:
    - List of keys present in real but not in fake: {('next', 'action_mask')},
    - List of keys present in fake but not in real: set().

Good catch, the action mask should be registered in the observation spec too!