CosTrader Env from scratch... and transform problem

Hello Here is a custom Env. I would like to discussmy mystakes because it pass “check_env_specs” (see Tests at the bottom of the code) but not after transformations…
CosTrader.py

import torch
from tensordict.tensordict import TensorDict, TensorDictBase
from torchrl.data import Bounded, Composite, Unbounded, Categorical, Binary
from torchrl.envs import EnvBase
from torchrl.envs.utils import check_env_specs

from typing import Optional

#############
rendering = True
#############
if rendering:
    import pygame
    from pygame import gfxdraw

def gen_params(batch_size=None) -> TensorDictBase:
    if batch_size is None:
        batch_size = []
    # f(x) = a.cos(b(x-h))+k
    td = TensorDict({"params": TensorDict({"a": 0.2, # ampli
                                           "b": 1, # freq
                                           "h": 0.5, # phase
                                           "k": 1.25, # niveau
                                           "dt": 0.05},
                                          [],) },
                    [],)

    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td

def make_composite_from_td(td):
    composite = Composite({key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase)
                               else Unbounded( dtype=tensor.dtype,
                                               device=tensor.device,
                                               shape=tensor.shape )
                               for key, tensor in td.items()},
                               shape=td.shape )
    return composite


class CosTraderEnv(EnvBase):
    batch_locked = False
    AUTO_UNWRAP_TRANSFORMED_ENV = False
    
    
    def __init__(self, td_params=None, seed=None, device="cpu", batch_size = []):
        super().__init__(device=device, batch_size=batch_size)
        
        if td_params is None:
            td_params = self.gen_params(batch_size)
        
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

        self.screen = None
        self.clock = None

    def _step(self,tensordict):
        a = tensordict["params", "a"]
        b = tensordict["params", "b"]
        h = tensordict["params", "h"]
        k = tensordict["params", "k"]
        dt = tensordict["params", "dt"]
        
        bid   = tensordict["bid"]      # cours
        angle = tensordict["angle"]    # tg à la courbe
        posA  = tensordict["posA"]     # position Achat
        posV  = tensordict["posV"]     # position Vente
        solde = tensordict["solde"]    # solde
        t     = tensordict["t"]        # temps
        action_mask= tensordict["action_mask"]
        
        u = tensordict["action"].squeeze(-1)
        
        #action:  0 = close,   1 = rien,    2 = Achat
        C = torch.where(u == 0, 1, 0)        
        R = torch.where(u == 1, 1, 0)
        A = torch.where(u == 2, 1, 0)
        V = torch.where(u == 3, 1, 0)
        
        no_pos = torch.where((posA == 0)&(posV == 0), 1, 0)
        in_posA = torch.where(posA >  0, 1, 0)
        in_posV = torch.where(posV >  0, 1, 0)
        
        addA = torch.where(u == 2, bid*no_pos, 0)
        addV = torch.where(u == 3, bid*no_pos, 0) 

        # add new pos to old pos
        new_posA = posA + addA
        new_posV = posV + addV

        # remove closed pos
        new_posA = torch.where(u == 0, 0, new_posA)
        new_posV = torch.where(u == 0, 0, new_posV)
        
        new_t = t + dt        
        new_bid = a*(b*(t-h)).cos()+k  # f(x) = a.cos(b(x-h))+k
        new_angle = (new_bid - bid)/dt
        new_solde = new_bid - posA - posV

        ##################### REWARD ########################
        
        reward_C = C * in_posA *(bid - (posA + posV)) #si posA >0 => posV=0 et inv
        reward_R = (R* (in_posA + in_posV) - R* no_pos) *(new_bid-bid)
        reward_A = A* no_pos *(new_bid-bid)
        reward_V = V* no_pos *(new_bid-bid)
        reward = reward_C *10+ reward_R + reward_A + reward_V
        
        #### mask adaptation                               action en cours
        lut = torch.tensor([  [False, True,  True, True],  # action 0 close
                              [False, True,  True, True],  # action 1 rien 
                              [True, True, False, False],  # action 2 achat
                              [True, True, False, False],  # action 3 vente
                              ], dtype=torch.bool)

        new_mask = lut[u]
        

        #done = C.bool()    ?
        done = torch.zeros_like(reward, dtype=torch.bool)
        
        nextTD = TensorDict({"params": tensordict["params"],
                             "angle": new_angle,
                             "bid": new_bid,
                             "posA": new_posA,
                             "posV": new_posV,
                             "t": new_t,
                             "solde":new_solde,
                             "reward": reward,
                             "done": done,
                             "action_mask":new_mask},
                            tensordict.shape,)
        
        if rendering:
            self.state = (angle.tolist(),bid.tolist(),new_posA.tolist(),new_posV.tolist())
            self.last_u = 0
            self.render()
        
        return nextTD


    def _reset(self, tensordict):
        if tensordict is None or tensordict.is_empty():
            tensordict = self.gen_params(batch_size=self.batch_size)
        t = torch.zeros(tensordict.shape, device = self.device)
        posA = torch.zeros(tensordict.shape, device = self.device)
        posV = torch.zeros(tensordict.shape, device = self.device)
        solde= torch.zeros(tensordict.shape, device = self.device)
        a = tensordict["params", "a"]
        b = tensordict["params", "b"]
        h = torch.rand(tensordict.shape)*4-2
        k = tensordict["params", "k"]
        dt = tensordict["params", "dt"]
        tensordict["params", "h"]=h
        bid =  a* (b*(t-h)).cos() + k #a.cos(b(x-h))+k
        angle = torch.zeros(tensordict.shape, device = self.device) #...
        
        new_action_mask = self._make_action_mask()


        out = TensorDict({"params": tensordict["params"],
                          "angle": angle,
                          "bid": bid,
                          "posA": posA,
                          "posV": posV,
                          "t": t,
                          "solde":solde,
                          "action_mask":new_action_mask},
                         batch_size=tensordict.shape,)

        if rendering:
            self.last_u = None 
            self.state = (angle.tolist(),bid.tolist(),posA.tolist(),posV.tolist())
            self.render()

        return out

    def _make_spec(self, td_params):
        self.batch_size = getattr(self, "batch_size", td_params.shape)
        if not isinstance(self.batch_size, torch.Size):
            self.batch_size = torch.Size(self.batch_size)

        # Action spec
        self.action_spec = Categorical(
            n=4,
            shape=(*self.batch_size, 1)
        )

        # Observation spec
        self.observation_spec = Composite(
            angle=Bounded(
                low=-torch.pi/2,
                high=torch.pi/2,
                shape=self.batch_size,
                dtype=torch.float32
            ),
            bid=Unbounded(
                shape=self.batch_size,
                dtype=torch.float32
            ),
            posA=Unbounded(
                shape=self.batch_size,
                dtype=torch.float32
            ),
            posV=Unbounded(
                shape=self.batch_size,
                dtype=torch.float32
            ),
            t=Bounded(
                low=0,
                high=1000,
                shape=self.batch_size,
                dtype=torch.float32
            ),
            solde=Unbounded(
                shape=self.batch_size,
                dtype=torch.float32
            ),
            action_mask=Binary(
                n=4,
                dtype=torch.bool,
                shape=(*self.batch_size, 4)
            ),
            params=make_composite_from_td(td_params["params"]),
            shape=self.batch_size
        )

        self.state_spec = self.observation_spec.clone()

        self.reward_spec = Unbounded(
            shape=(*self.batch_size, 1)
        )
            
    gen_params = staticmethod(gen_params)
    
    def _make_action_mask(self):
        mask = torch.tensor([False, True, True, True])  # n=4
        mask = mask.expand(*self.batch_size, 4)
        return mask
    

    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

    def get_obskeys(self):
        return ["angle", "bid","posA","t","solde"]

    def render(self):
        
        if self.screen is None:
            self.screen_dim = 400
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((self.screen_dim,self.screen_dim))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                raise SystemExit

        self.surf = pygame.Surface((self.screen_dim, self.screen_dim))
        self.surf.fill((255, 255, 255))

        bound = 2.2
        scale = self.screen_dim / (bound * 2)
        offset = self.screen_dim // 2

        rod_length = 1 * scale
        rod_width = 0.02 * scale
        l, r, t, b = 0, rod_length, rod_width / 2, -rod_width / 2
        coords = [(l, b), (l, t), (r, t), (r, b)]
        transformed_coords = []
        for c in coords:
            try:
                c = pygame.math.Vector2(c).rotate_rad(self.state[0])# + np.pi / 2)
            except:
                c = pygame.math.Vector2(c).rotate_rad(self.state[0][0])# + np.pi / 2)
            c = (c[0] + offset, c[1] + offset)
            transformed_coords.append(c)
        gfxdraw.aapolygon(self.surf, transformed_coords, (204, 77, 77))
        gfxdraw.filled_polygon(self.surf, transformed_coords, (204, 77, 77))

        gfxdraw.aacircle(self.surf, offset, offset, int(rod_width / 2), (204, 77, 77))
        gfxdraw.filled_circle(
            self.surf, offset, offset, int(rod_width / 2), (204, 77, 77)
        )

        # drawing bid
        gfxdraw.filled_circle(self.surf, 10,int(9*(   1.4   -1)* scale), int(0.05 * scale), (0, 0, 255))
        try:
            bid = self.state[1]
            gfxdraw.filled_circle(self.surf, 5,int(9*(bid-1)* scale), int(0.05 * scale), (0, 0, 0))
        except:
            bid = self.state[1][0]
            #gfxdraw.aacircle(self.surf, 5, int(9*(bid-1)* scale), int(0.05 * scale), (0, 0, 0))
            gfxdraw.filled_circle(self.surf, 5,int(9*(bid-1)* scale), int(0.05 * scale), (0, 0, 0))



        # drawing posA        
        try:
            vac = self.state[2]
            gfxdraw.filled_circle(self.surf, self.screen_dim-5,int(9*(vac-1)* scale), int(0.05 * scale), (0, 255, 0))
        except:
            vac = self.state[2][0]
            #gfxdraw.aacircle(self.surf, 5, int(9*(bid-1)* scale), int(0.05 * scale), (0, 0, 0))
            gfxdraw.filled_circle(self.surf, self.screen_dim-5,int(9*(vac-1)* scale), int(0.05 * scale), (0, 255, 0))

        # drawing posV        
        try:
            vac = self.state[3]
            gfxdraw.filled_circle(self.surf, self.screen_dim-5,int(9*(vac-1)* scale), int(0.05 * scale), (255, 0, 0))
        except:
            vac = self.state[3][0]
            #gfxdraw.aacircle(self.surf, 5, int(9*(bid-1)* scale), int(0.05 * scale), (0, 0, 0))
            gfxdraw.filled_circle(self.surf, self.screen_dim-5,int(9*(vac-1)* scale), int(0.05 * scale), (255, 0, 0))


        #print(bid,vac)
        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
       
        
        self.clock.tick(30)
        pygame.display.flip()

    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()



if __name__ == "__main__":
    ### Tests ###
    env = CosTraderEnv(batch_size = torch.Size([]))
    check_env_specs(env)
    env = CosTraderEnv(batch_size = torch.Size([10]))
    check_env_specs(env)

    print("\n* observation_spec:", env.observation_spec)
    print("\n* action_spec:", env.action_spec)
    print("\n* reward_spec:", env.reward_spec)
    
    print("\n* random action 5: \n", env.action_spec.rand(torch.Size([5])))
    print("\n* random obs 5: \n", env.observation_spec.rand(torch.Size([5])))
    td = env.reset()
    print("\n* reset tensordict", td)

    td = env.rand_step(td)
    print("\n* random step tensordict", td)
    env.close()

And here are the transformations:

import CosTrader
env = CosTrader.CosTraderEnv(batch_size = torch.Size([10]))

# Transform ******************
obs_keys = env.get_obskeys()
env = TransformedEnv(
    env,
    # Unsqueezes the observations that we will concatenate    
    UnsqueezeTransform(
        dim=-1,
        in_keys=obs_keys,
        in_keys_inv=obs_keys,
    ),
)
cat_transform = CatTensors(
    in_keys=obs_keys, dim=-1, out_key="observation", del_keys=False
)
env.append_transform(cat_transform)
env = TransformedEnv(
    env,
    Compose(
        DoubleToFloat(),
        StepCounter(),
        ActionMask()
    ),
)

And here is the traceback when “check_env_specs” after transform:

Traceback (most recent call last):
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/data/tensor_specs.py", line 3931, in update_mask
    mask = mask.expand(_remove_neg_shapes(*self.shape, self.space.n))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The expanded size of the tensor (1) must match the existing size (10) at non-singleton dimension 1.  Target sizes: [10, 1, 4].  Tensor sizes: [10, 4]

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/fauche/Documents/RenforcementLearning/CosPPO2.py", line 101, in <module>
    check_env_specs(env)
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/envs/utils.py", line 759, in check_env_specs
    real_tensordict = env.rollout(
                      ^^^^^^^^^^^^
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/envs/common.py", line 3361, in rollout
    tensordict = self.reset(tensordict)
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/envs/common.py", line 2858, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 1226, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 1640, in _reset
    tensordict_reset = t._reset(tensordict, tensordict_reset)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 8922, in _reset
    return self._call(tensordict_reset)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 8915, in _call
    self.action_spec.update_mask(mask.to(self.action_spec.device))
  File "/home/fauche/Documents/IAvenv/lib/python3.11/site-packages/torchrl/data/tensor_specs.py", line 3933, in update_mask
    raise RuntimeError("Cannot expand mask to the desired shape.") from err
RuntimeError: Cannot expand mask to the desired shape.

when changing the action_mask shape to:

            ...    
            action_mask=Binary(
                n=4,
                dtype=torch.bool,
                shape=(*self.batch_size, 1,4)
            ),
            ... 

and function _make_action_mask to:

    def _make_action_mask(self):
        mask = torch.tensor([False, True, True, True])  # n=4
        mask = mask.expand(*self.batch_size,1, 4)
        return mask

the transformed env pass “check_env_specs” for unbatched and batched ones…the problem now is in the training loop ! (to be continued…)