Environments from scratch with TorchRL

Hello, I'm opening this topic… As for my contribution, here is the simplest one you can imagine!

from typing import Optional
import matplotlib.pyplot as plt

import torch
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase 
from torchrl.envs.utils import check_env_specs

"""
A very very simple objective: An actor appears at a random location x on a line and must reach
the center of the line...
"""

################
DEFAULT_X = 1
rendering = True
render_fps= 30
################

if rendering:
    import pygame
    from pygame import gfxdraw

def gen_params(batch_size=None):
    if batch_size is None:
        batch_size = []

    td = TensorDict({},[])
    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td

class SimpleEnv(EnvBase):
    batch_locked = False
    def __init__(self, td_params=None, seed=None, device="cpu"):
        super().__init__(device=device, batch_size=[])
        
        if td_params is None:
            td_params = self.gen_params()
        
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

        self.screen = None
        self.clock = None
        
    def _step(self,tensordict):
        x = tensordict["x"] 
        
        u = tensordict["action"].squeeze(-1)
        #u = u.clamp(-1, 1)
                

        reward = -torch.abs(x + u) ** 2

        new_x = x + u
        
        done = torch.zeros_like(reward, dtype=torch.bool)
                                     
        nextTD = TensorDict({"x":new_x,
                             "reward": reward,
                             "done": done},
                            tensordict.shape)

        if rendering:
            state = {"x":tensordict["x"].tolist()}
            self.render(state)
        return nextTD
    
    def _reset(self, tensordict):
        if tensordict is None:            
            tensordict = self.gen_params(batch_size=self.batch_size)

        high_x = torch.tensor(DEFAULT_X, device=self.device)
        low_x = -high_x

        x = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x)
        
        out = TensorDict({"x":x},
                         batch_size=tensordict.shape,)

        if rendering:
            state = {"x":out["x"].tolist()}
            self.render(state)
        return out
    
    gen_params = staticmethod(gen_params)

    def _make_spec(self,td):        
        self.observation_spec = CompositeSpec(x=BoundedTensorSpec(minimum = -DEFAULT_X,
                                                                   maximum = DEFAULT_X,
                                                                   shape = (),
                                                                   dtype = torch.float32),
                                              shape=())

        #self.state_spec = self.observation_spec.clone()
        
        self.action_spec = BoundedTensorSpec(minimum = -DEFAULT_X/100,
                                             maximum =  DEFAULT_X/100,
                                             shape = (1,),
                                             dtype = torch.float32)
        
        self.reward_spec = UnboundedContinuousTensorSpec(shape = (*td.shape,1))

        
    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

    def render(self,state):
        self.screen_w = 600
        self.screen_h = 400

        if self.screen is None:
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((self.screen_w, self.screen_h))
            pygame.display.set_caption('View of the first element of the batch')
        if self.clock is None:
            self.clock = pygame.time.Clock()

        self.surf = pygame.Surface((self.screen_w, self.screen_h))
        self.surf.fill((255, 255, 255))

        bound = DEFAULT_X
        scale = self.screen_w / (bound * 2)
        offset = self.screen_w // 2
        agent_width = int(0.05 * scale)

        gfxdraw.hline(self.surf, 0, self.screen_w, self.screen_h//2, (0, 0, 0))
        
        if type(state["x"]) != float:
            x = state["x"][0]
        else:
            x = state["x"]

        # drawing agent    
        gfxdraw.aacircle(self.surf, int(x* scale )+offset, self.screen_h//2, int(agent_width / 2), (255, 0, 0))
        gfxdraw.filled_circle(self.surf, int(x* scale)+offset, self.screen_h//2, int(agent_width / 2), (255, 0, 0))

        # drawing center
        gfxdraw.aacircle(self.surf, offset, self.screen_h//2, int(agent_width / 4), (0, 0, 0))
       

        #self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        
        self.clock.tick(render_fps)
        pygame.display.flip()

    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()

    def get_obskeys(self):
        return ["x"]


if __name__ == "__main__":
    # Tests on the environment
    env = SimpleEnv()    
    check_env_specs(env)
    
    print("\n* observation_spec:", env.observation_spec)
    print("\n* action_spec:", env.action_spec)
    print("\n* reward_spec:", env.reward_spec)
    print("\n* random action: \n", env.action_spec.rand())
    td = env.reset()
    print("\n* reset",td)
    td = env.rand_step(td)
    print("\n* random step tensordict", td)

    rollout = env.rollout(3)
    print("\n* rollout of three steps:", rollout)
    print("\n* Shape of the rollout TensorDict:", rollout.batch_size)
    
    batch_size = 10   
    td = env.reset(env.gen_params(batch_size=[batch_size]))
    print("\n* reset (batch size of 10)", td)
    env.close()    

And here is a minimal network for training the agent on this minimal env:

import torch
from torch import nn

from collections import defaultdict
import matplotlib.pyplot as plt

from tensordict.nn import TensorDictModule
from torchrl.envs import CatTensors,TransformedEnv,UnsqueezeTransform
                          
from torchrl.envs.transforms.transforms import _apply_to_composite
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec

import Simple1D1AgentEnv
env = Simple1D1AgentEnv.SimpleEnv()


obs_keys = env.get_obskeys()
env = TransformedEnv(
    env,
    UnsqueezeTransform(
        unsqueeze_dim=-1,
        in_keys=obs_keys,
        in_keys_inv=obs_keys,
    ),
)
cat_transform = CatTensors(
    in_keys=obs_keys, dim=-1, out_key="observation", del_keys=False
)
env.append_transform(cat_transform)


torch.manual_seed(0)
env.set_seed(0)
E = 64
net = nn.Sequential(
    nn.LazyLinear(E),
    nn.Tanh(),
    nn.LazyLinear(E),
    nn.Tanh(),
    nn.LazyLinear(E),
    nn.Tanh(),
    nn.LazyLinear(1),
)
policy = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=["action"],
)

optim = torch.optim.Adam(policy.parameters(), lr=2e-3)

batch_size = 32
frames = 10_000
iterations = frames // batch_size

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, frames)
logs = defaultdict(list)

for i in range(iterations):
    init_td = env.reset(env.gen_params(batch_size=[batch_size]))
    rollout = env.rollout(100, policy, tensordict=init_td, auto_reset=False)
    traj_return = rollout["next", "reward"].mean()
    (-traj_return).backward()
    gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    progress = 100 * i / iterations
    print(
        f"{progress: 3.1f}%, reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean().item())
    logs["grad_norm"].append(gn)
    scheduler.step()

env.close()
def plot():
    import matplotlib    
    from matplotlib import pyplot as plt
    
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 3, 1)
    plt.plot(logs["return"])
    plt.title("returns")
    plt.xlabel("iteration")

    plt.subplot(1, 3, 2)
    plt.plot(logs["last_reward"])
    plt.title("last reward")
    plt.xlabel("iteration")

    plt.subplot(1, 3, 3)
    plt.plot(logs["grad_norm"])
    plt.title("grad_norm")
    plt.xlabel("iteration")

    plt.show()
plot()

Thanks a lot for the tutorial
https://pytorch.org/rl/tutorials/pendulum.html

Now I'm trying to extend the environment to 2 dimensions: a point at a random location in a plane. As you can see, I need to “help” with some trigonometry in _step if I use a 1-dimensional action… and with this help, after one training pass the network has the solution…
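
In other words, the single action u is read as a step of length u taken straight toward the center, so both coordinates get scaled by the same factor (1 - u/d). A quick standalone check of that update rule with made-up values (not part of the env, just to illustrate the geometry):

import torch

x, y = torch.tensor(0.6), torch.tensor(-0.8)   # distance to the origin is 1.0
u = torch.tensor(0.1)                          # step length toward the center
d = torch.sqrt(x**2 + y**2)                    # distance to the center
new_x, new_y = x * (1 - u / d), y * (1 - u / d)
print(new_x, new_y)                            # tensor(0.5400) tensor(-0.7200)
print(torch.sqrt(new_x**2 + new_y**2))         # tensor(0.9000): exactly 0.1 closer to the center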

from typing import Optional
import matplotlib.pyplot as plt

import torch
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase 
from torchrl.envs.utils import check_env_specs

"""
A very simple objective: An actor appears at a random location xy on a plane and must reach
the center of the plane...
"""

################
DEFAULT_X = 1
rendering = True
render_fps= 30
################

if rendering:
    import pygame
    from pygame import gfxdraw

def gen_params(batch_size=None):
    if batch_size is None:
        batch_size = []

    td = TensorDict({},[])
    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td

class Simple2DEnv(EnvBase):
    batch_locked = False
    def __init__(self, td_params=None, seed=None, device="cpu"):
        super().__init__(device=device, batch_size=[])
        
        if td_params is None:
            td_params = self.gen_params()
        
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

        self.screen = None
        self.clock = None
        
    def _step(self,tensordict):
        x = tensordict["x"]
        y = tensordict["y"] 

        u = tensordict["action"].squeeze(-1)
        u = u.clamp(-1, 1)
        
        # Distance between 2 points: d = √((x2 - x1)² + (y2 - y1)²)
        # center of the plane => x1 = 0 and y1 = 0
        # reward estimates -dist(point, center)
        d = torch.sqrt(torch.square(x) + torch.square(y))
        
        #sin = y/d
        #cos = x/d
        
        reward = -d

        new_x = x*(1 - u/d)  #x - u* cos
        new_y = y*(1 - u/d)  #y - u* sin
        new_x =  torch.nan_to_num(new_x)
        new_y =  torch.nan_to_num(new_y)
        
        done = torch.zeros_like(reward, dtype=torch.bool)
                                     
        nextTD = TensorDict({"x":new_x,
                             "y":new_y,
                             "reward": reward,
                             "done": done},
                            tensordict.shape)

        if rendering:
            state = {"x":tensordict["x"].tolist(),"y":tensordict["y"].tolist()}
            self.render(state)
        return nextTD
    
    def _reset(self, tensordict):
        if tensordict is None:            
            tensordict = self.gen_params(batch_size=self.batch_size)

        high_x = torch.tensor(DEFAULT_X, device=self.device)
        low_x = -high_x

        x = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x)
        y = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x)
        
        out = TensorDict({"x":x, "y":y},
                         batch_size=tensordict.shape,)

        if rendering:
            state = {"x":out["x"].tolist(),"y":out["y"].tolist()}
            self.render(state)
        return out
    
    gen_params = staticmethod(gen_params)

    def _make_spec(self,td):        
        self.observation_spec = CompositeSpec(x=BoundedTensorSpec(minimum = -DEFAULT_X,
                                                                   maximum = DEFAULT_X,
                                                                   shape = (),
                                                                   dtype = torch.float32),
                                              
                                              y=BoundedTensorSpec(minimum = -DEFAULT_X,
                                                                   maximum = DEFAULT_X,
                                                                   shape = (),
                                                                   dtype = torch.float32),
                                              shape=())

        
        self.action_spec = BoundedTensorSpec(minimum = -DEFAULT_X/1000,
                                             maximum =  DEFAULT_X/1000,
                                             shape = (1,),
                                             dtype = torch.float32)
        
        self.reward_spec = UnboundedContinuousTensorSpec(shape = (*td.shape,1))

        
    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

    def render(self,state):
        self.screen_w = 600
        self.screen_h = 600

        if self.screen is None:
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((self.screen_w, self.screen_h))
            pygame.display.set_caption('View of the first element of the batch')
        if self.clock is None:
            self.clock = pygame.time.Clock()

        self.surf = pygame.Surface((self.screen_w, self.screen_h))
        self.surf.fill((255, 255, 255))

        bound = DEFAULT_X
        scale = self.screen_w / (bound * 2)
        offset = self.screen_w // 2
        agent_width = int(0.05 * scale)

        #gfxdraw.hline(self.surf, 0, self.screen_w, self.screen_h//2, (0, 0, 0))
        
        if type(state["x"]) != float:
            x = state["x"][0]
            y = state["y"][0]
        else:
            x = state["x"]
            y = state["y"]

        # drawing agent
        X = int(x* scale +offset)
        Y = int(y* scale +offset)
        D = int(agent_width / 2)
        
        try:
            gfxdraw.aacircle(self.surf, X, Y, D, (255, 0, 0))
            gfxdraw.filled_circle(self.surf, X, Y, D, (255, 0, 0))
        except Exception:
            # coordinates went outside the drawable range; log them instead of crashing
            print(X, Y, D)
        # drawing center
        gfxdraw.aacircle(self.surf, offset, self.screen_h//2, int(agent_width / 4), (0, 0, 0))
       

        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        
        self.clock.tick(render_fps)
        pygame.display.flip()

    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()

    def get_obskeys(self):
        return ["x","y"]


if __name__ == "__main__":
    # Tests on the environment
    env = Simple2DEnv()    
    check_env_specs(env)
    
    print("\n* observation_spec:", env.observation_spec)
    print("\n* action_spec:", env.action_spec)
    print("\n* reward_spec:", env.reward_spec)
    print("\n* random action: \n", env.action_spec.rand())
    td = env.reset()
    print("\n* reset",td)
    td = env.rand_step(td)
    print("\n* random step tensordict", td)

    rollout = env.rollout(3)
    print("\n* rollout of three steps:", rollout)
    print("\n* Shape of the rollout TensorDict:", rollout.batch_size)
    
    batch_size = 10   
    td = env.reset(env.gen_params(batch_size=[batch_size]))
    print("\n* reset (batch size of 10)", td)
    env.close()    

So I want to use a 2-dimensional action (one component for x, one for y) and a 2-dimensional reward as well… and there I have a problem and need help!

I have tried to change the shapes (to 2) in _make_spec for action_spec and reward_spec, roughly like this (sketched from memory, this is the part I am least sure about):
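
    # sketch of what I changed in _make_spec (from memory)
    self.action_spec = BoundedTensorSpec(minimum = -DEFAULT_X/1000,
                                         maximum =  DEFAULT_X/1000,
                                         shape = (2,),
                                         dtype = torch.float32)

    self.reward_spec = UnboundedContinuousTensorSpec(shape = (*td.shape, 2))

Then I made these changes in _step: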

    x = tensordict["x"]
    y = tensordict["y"]
    print("action:", tensordict["action"])

    u = tensordict["action"].squeeze(-1)
    print("squeezed:", u)
    u = u.clamp(-1, 1)
    ux = u[0]
    uy = u[1]

    # No trigo now

    rewardx = -torch.abs(x + ux) ** 2
    rewardy = -torch.abs(y + uy) ** 2
    rewardx.unsqueeze_(0)
    rewardy.unsqueeze_(0)
    reward = torch.cat((rewardx, rewardy), dim=0)
    print("reward:", reward)


but when testing, the prints give:

action: tensor([-0.0008, -0.0007])
squeezed: tensor([-0.0008, -0.0007])
reward: tensor([-6.9054e-05, -2.8698e-01])

and finally check_env_specs(env) raises this error:
File "/home/fauche/.local/lib/python3.9/site-packages/torchrl/envs/common.py", line 1331, in step
    next_tensordict = self._step_proc_data(next_tensordict)
File "/home/fauche/.local/lib/python3.9/site-packages/torchrl/envs/common.py", line 1423, in _step_proc_data
    self._complete_done(self.full_done_spec, next_tensordict_out)
File "/home/fauche/.local/lib/python3.9/site-packages/torchrl/envs/common.py", line 1360, in _complete_done
    data.set(key, val.reshape(shape))
RuntimeError: shape '[1]' is invalid for input of size 2

If someone can explain to me where the mistake is… Help!

OK, I got it! The problem was simply updating the network to output 2 dimensions, plus some shape adaptations in _step… Here is the code:

from typing import Optional
import matplotlib.pyplot as plt

import torch
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase 
from torchrl.envs.utils import check_env_specs

"""
A very simple objective: An actor appears at a random location xy on a plane and must reach
the center of the plane...
"""

################
DEFAULT_X = 1
rendering = True
render_fps= 30
################

if rendering:
    import pygame
    from pygame import gfxdraw

def gen_params(batch_size=None):
    if batch_size is None:
        batch_size = []

    td = TensorDict({},[])
    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td

class Simple2DEnv(EnvBase):
    batch_locked = False
    def __init__(self, td_params=None, seed=None, device="cpu"):
        super().__init__(device=device, batch_size=[])
        
        if td_params is None:
            td_params = self.gen_params()
        
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

        self.screen = None
        self.clock = None
        
    def _step(self,tensordict):
        x = tensordict["x"]
        y = tensordict["y"] 
        
        u = tensordict["action"].t()
        #u = u.clamp(-1, 1)
        
        ux = u[0]        
        uy = u[1]
                
        # No trigo now
                
        rewardx =  -torch.abs(x+ux)**2
        rewardy =  -torch.abs(y+uy)**2

        rew_x = rewardx.view(*tensordict.shape, 1)
        rew_y = rewardy.view(*tensordict.shape, 1)

        reward = torch.stack((rew_x,rew_y), dim = -1)
        reward.squeeze_()
        
        new_x = x + ux
        new_y = y + uy
        
        done = torch.zeros_like(rew_x, dtype=torch.bool)
                                     
        nextTD = TensorDict({"x":new_x,
                             "y":new_y,
                             "reward": reward,
                             "done": done},
                            tensordict.shape)

        if rendering:
            state = {"x":tensordict["x"].tolist(),"y":tensordict["y"].tolist()}
            self.render(state)
        return nextTD
    
    def _reset(self, tensordict):
        if tensordict is None:            
            tensordict = self.gen_params(batch_size=self.batch_size)

        high_x = torch.tensor(DEFAULT_X, device=self.device)
        low_x = -high_x

        x = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x)
        y = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x)
        
        out = TensorDict({"x":x, "y":y},
                         batch_size=tensordict.shape,)

        if rendering:
            state = {"x":out["x"].tolist(),"y":out["y"].tolist()}
            self.render(state)
        return out
    
    gen_params = staticmethod(gen_params)

    def _make_spec(self,td):        
        self.observation_spec = CompositeSpec(x=BoundedTensorSpec(minimum = -DEFAULT_X,
                                                                   maximum = DEFAULT_X,
                                                                   shape = (),
                                                                   dtype = torch.float32),
                                              
                                              y=BoundedTensorSpec(minimum = -DEFAULT_X,
                                                                   maximum = DEFAULT_X,
                                                                   shape = (),
                                                                   dtype = torch.float32),
                                              shape=())

        
        self.action_spec = BoundedTensorSpec(minimum = -DEFAULT_X/1000,
                                             maximum =  DEFAULT_X/1000,
                                             shape = (*td.shape,2),
                                             dtype = torch.float32)
        
        self.reward_spec = UnboundedContinuousTensorSpec(shape = (*td.shape,2,))

        
    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

    def render(self,state):
        self.screen_w = 600
        self.screen_h = 600

        if self.screen is None:
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((self.screen_w, self.screen_h))
            pygame.display.set_caption('View of the first element of the batch')
        if self.clock is None:
            self.clock = pygame.time.Clock()

        self.surf = pygame.Surface((self.screen_w, self.screen_h))
        self.surf.fill((255, 255, 255))

        bound = DEFAULT_X
        scale = self.screen_w / (bound * 2)
        offset = self.screen_w // 2
        agent_width = int(0.05 * scale)
        
        if type(state["x"]) != float:
            x = state["x"][0]
            y = state["y"][0]
        else:
            x = state["x"]
            y = state["y"]

        # drawing agent
        X = int(x* scale +offset)
        Y = int(y* scale +offset)
        D = int(agent_width / 2)
        
        gfxdraw.aacircle(self.surf, X, Y, D, (255, 0, 0))
        gfxdraw.filled_circle(self.surf, X, Y, D, (255, 0, 0))

        # drawing center
        gfxdraw.aacircle(self.surf, offset, self.screen_h//2, int(agent_width / 4), (0, 0, 0))
       

        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        
        self.clock.tick(render_fps)
        pygame.display.flip()

    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()

    def get_obskeys(self):
        return ["x","y"]


if __name__ == "__main__":
    # Tests on the environment
    env = Simple2DEnv()    
    check_env_specs(env)
    
    print("\n* observation_spec:", env.observation_spec)
    print("\n* action_spec:", env.action_spec)
    print("\n* reward_spec:", env.reward_spec)
    print("\n* random action: \n", env.action_spec.rand())
    td = env.reset()
    print("\n* reset",td)
    td = env.rand_step(td)
    print("\n* random step tensordict", td)

    rollout = env.rollout(3)
    print("\n* rollout of three steps:", rollout)
    print("\n* Shape of the rollout TensorDict:", rollout.batch_size)
    
    batch_size = 10   
    td = env.reset(env.gen_params(batch_size=[batch_size]))
    print("\n* reset (batch size of 10)", td)
    env.close()    

And in the simple network test code, I just add this after the definition of env:

action_shape = env.action_spec.shape[-1] 

and change the output layer of the network to:

nn.LazyLinear(action_shape)
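
So the policy network from the 1D training script becomes (same imports and same E = 64 as before, and env already defined; only the last layer changes):

from torch import nn

E = 64
action_shape = env.action_spec.shape[-1]   # 2 for the 2D env
net = nn.Sequential(
    nn.LazyLinear(E),
    nn.Tanh(),
    nn.LazyLinear(E),
    nn.Tanh(),
    nn.LazyLinear(E),
    nn.Tanh(),
    nn.LazyLinear(action_shape),           # only this last layer changed
)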

And just for fun, the 3D version!

from typing import Optional
import matplotlib.pyplot as plt

import torch
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase 
from torchrl.envs.utils import check_env_specs

"""
A simple objective: An actor appears at a random location xyz in 3D space and
must reach the origin of the space...
"""

################
DEFAULT_X = 1
rendering = True
render_fps= 30
################

if rendering:
    import pygame
    from pygame import gfxdraw

def gen_params(batch_size=None):
    if batch_size is None:
        batch_size = []

    td = TensorDict({},[])
    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td

class Simple3DEnv(EnvBase):
    batch_locked = False
    def __init__(self, td_params=None, seed=None, device="cpu"):
        super().__init__(device=device, batch_size=[])
        
        if td_params is None:
            td_params = self.gen_params()
        
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

        self.screen = None
        self.clock = None
        
    def _step(self,tensordict):
        x = tensordict["x"]
        y = tensordict["y"]
        z = tensordict["z"]
        
        u = tensordict["action"].t()
        #u = u.clamp(-1, 1)
        
        ux = u[0]        
        uy = u[1]
        uz = u[2]
                
        rewardx =  -torch.abs(x+ux)**2
        rewardy =  -torch.abs(y+uy)**2
        rewardz =  -torch.abs(z+uz)**2

        rew_x = rewardx.view(*tensordict.shape, 1)
        rew_y = rewardy.view(*tensordict.shape, 1)
        rew_z = rewardz.view(*tensordict.shape, 1)

        reward = torch.stack((rew_x,rew_y,rew_z), dim = -1)
        reward.squeeze_()

        new_x = x + ux
        new_y = y + uy
        new_z = z + uz
        
        done = torch.zeros_like(rew_x, dtype=torch.bool)
                                     
        nextTD = TensorDict({"x":new_x,
                             "y":new_y,
                             "z":new_z,
                             "reward": reward,
                             "done": done},
                            tensordict.shape)

        if rendering:
            state = {"x":tensordict["x"].tolist(),"y":tensordict["y"].tolist(),"z":tensordict["z"].tolist()}
            self.render(state)
        return nextTD
    
    def _reset(self, tensordict):
        if tensordict is None:            
            tensordict = self.gen_params(batch_size=self.batch_size)

        high_x = torch.tensor(DEFAULT_X, device=self.device)
        low_x = -high_x

        x = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x)
        y = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x)
        z = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x)
        out = TensorDict({"x":x, "y":y, "z":z},
                         batch_size=tensordict.shape,)

        if rendering:
            state = {"x":out["x"].tolist(),"y":out["y"].tolist(),"z":out["z"].tolist()}
            self.render(state)
        return out
    
    gen_params = staticmethod(gen_params)

    def _make_spec(self,td):        
        self.observation_spec = CompositeSpec(x=BoundedTensorSpec(minimum = -DEFAULT_X,
                                                                   maximum = DEFAULT_X,
                                                                   shape = (),
                                                                   dtype = torch.float32),
                                              
                                              y=BoundedTensorSpec(minimum = -DEFAULT_X,
                                                                   maximum = DEFAULT_X,
                                                                   shape = (),
                                                                   dtype = torch.float32),
                                              z=BoundedTensorSpec(minimum = -DEFAULT_X,
                                                                   maximum = DEFAULT_X,
                                                                   shape = (),
                                                                   dtype = torch.float32),

                                              shape=())

        
        self.action_spec = BoundedTensorSpec(minimum = -DEFAULT_X/1000,
                                             maximum =  DEFAULT_X/1000,
                                             shape = (*td.shape,3),
                                             dtype = torch.float32)
        
        self.reward_spec = UnboundedContinuousTensorSpec(shape = (*td.shape,3,))

        
    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

    def render(self,state):
        self.screen_w = 600
        self.screen_h = 600

        if self.screen is None:
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((self.screen_w, self.screen_h))
            pygame.display.set_caption('View of the first element of the batch')
        if self.clock is None:
            self.clock = pygame.time.Clock()

        self.surf = pygame.Surface((self.screen_w, self.screen_h))
        self.surf.fill((255, 255, 255))

        bound = DEFAULT_X
        scale = self.screen_w / (bound * 2)
         
        offset = self.screen_w // 2
        agent_width = int(0.05 * scale)
        
        if type(state["x"]) != float:
            x = state["x"][0]
            y = state["y"][0]
            z = state["z"][0]
        else:
            x = state["x"]
            y = state["y"]
            z = state["z"]

        # drawing agent        
        X = int(x* scale +offset)
        Y = int(y* scale +offset)
        if z >1:z=1
        if z <-1:z=-1
        D = int(agent_width*(1-z) / 2)
        
        gfxdraw.aacircle(self.surf, X, Y, D, (255, 0, 0))
        gfxdraw.filled_circle(self.surf, X, Y, D, (255, 0, 0))

        # drawing center
        gfxdraw.aacircle(self.surf, offset, self.screen_h//2, int(agent_width / 4), (0, 0, 0))
       

        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        
        self.clock.tick(render_fps)
        pygame.display.flip()

    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()

    def get_obskeys(self):
        return ["x","y","z"]


if __name__ == "__main__":
    # Tests on the environment
    env = Simple3DEnv()    
    check_env_specs(env)
    
    print("\n* observation_spec:", env.observation_spec)
    print("\n* action_spec:", env.action_spec)
    print("\n* reward_spec:", env.reward_spec)
    print("\n* random action: \n", env.action_spec.rand())
    td = env.reset()
    print("\n* reset",td)
    td = env.rand_step(td)
    print("\n* random step tensordict", td)

    rollout = env.rollout(3)
    print("\n* rollout of three steps:", rollout)
    print("\n* Shape of the rollout TensorDict:", rollout.batch_size)
    
    batch_size = 10   
    td = env.reset(env.gen_params(batch_size=[batch_size]))
    print("\n* reset (batch size of 10)", td)
    env.close()    

Here is my first multi-agent environment. I encountered quite a few difficulties, particularly with formatting the observations. At first I wanted to separate the coordinates of the agent and of the target into a CompositeSpec-style observation, but this generated quite a few problems (if the details interest anyone…), so I resigned myself to putting all the observations into a single tensor that I then slice apart in _step to recover the data.
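
Concretely, each agent's observation is one flat tensor of 6 values, agent coordinates first and target coordinates after, and _step slices it back apart. A tiny standalone illustration with made-up numbers (not part of the env):

import torch

# layout per agent: [x, y, z, a, b, c]  (agent position, then target position)
obs = torch.tensor([[0.2, -0.5, 0.1, 0.0, 0.3, -0.2]])   # batch of 1 agent observation
xyz = obs[..., 0:3]   # agent coords
abc = obs[..., 3:6]   # target coords
print(xyz)   # tensor([[ 0.2000, -0.5000,  0.1000]])
print(abc)   # tensor([[ 0.0000,  0.3000, -0.2000]])
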
Simple3D_NAgentsEnv.py:

import torch
print(torch.__version__)
import torchrl
print(torchrl.__version__)

from torchrl.envs import EnvBase
from typing import Optional
from tensordict.tensordict import TensorDict
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec,DiscreteTensorSpec
from torchrl.envs.utils import check_env_specs

"""
Objective: 3 actors appear at random locations in 3D space and must reach
a target at a random 3D location...
"""

################
DEFAULT_X = 1
rendering = True
render_fps= 30
################

if rendering:
    import pygame
    from pygame import gfxdraw

def gen_params(batch_size = None):
    if batch_size is None:
        batch_size = []
    td = TensorDict({},[])
    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td


class SimpleNagentsEnv(EnvBase):
    batch_locked = False

    def __init__(self,batch_size = [], seed = None, device = 'cpu'):        
        td = self.gen_params(batch_size = batch_size)
        super().__init__(device=device, batch_size=batch_size)
        
        self.n_agents = 3
        self._make_spec()
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)
        
        self.screen = None
        self.clock = None

    gen_params = staticmethod(gen_params)
    
    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

    def _step(self,tensordict):
        
        action = tensordict.get(("agents", "action"))  # [bs,nag,nac] 
        obs = tensordict.get(("agents", "observation"))# [bs,nag,nob] 
        
        agent_tds = []
        for i in range(self.n_agents):
            agent_action = action[:, i, ...]  #[bs,nac]
            agent_obs    = obs[:, i, ...]     #[bs,nob]
            
            xyz = agent_obs[:, 0:3, ...] # agent coords
            abc = agent_obs[:, 3:6, ...] # target coords
            
            new_xyz = xyz + agent_action   # [bs,nob]

            agent_rew     = -torch.abs(xyz + agent_action - abc)**2  # [bs,nrew]

            new_agent_obs =  torch.cat((new_xyz,abc),dim = -1)

            agent_td = TensorDict(
                source={
                    "observation": new_agent_obs,
                    "reward": agent_rew,
                },
                batch_size=self.batch_size,
                device=self.device,
            )
            agent_tds.append(agent_td)
        
        agent_tds = torch.stack(agent_tds, dim=1).to_tensordict()

        dones = torch.zeros(tensordict.shape, dtype = torch.bool)  
        
        nextTD = TensorDict(
            source={"agents": agent_tds, "done": dones, "terminated": dones.clone()},
            batch_size=self.batch_size,
            device=self.device,
        )

        if rendering:
            state = {"obs":nextTD["agents"]["observation"][0]}
            self.render(state)
        
        return nextTD
    
    def _reset(self,tensordict):
        if tensordict is None:
            tensordict = self.gen_params(batch_size=self.batch_size)
            self.batch_size = tensordict.shape
            
        high_x = torch.tensor(1, device=self.device)
        low_x = -high_x

        # Random target coords    
        a = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x/2 - low_x/2) + low_x/2).unsqueeze(-1)
        b = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x/2 - low_x/2) + low_x/2).unsqueeze(-1)
        c = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x/2 - low_x/2) + low_x/2).unsqueeze(-1)
        
        agent_tds = []
        for i in range(self.n_agents):
            # Random agent coords
            x = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x).unsqueeze(-1)
            y = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x).unsqueeze(-1)
            z = (torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x).unsqueeze(-1)

            obs = torch.stack((x,y,z,a,b,c),dim = -1).squeeze()
            
            agent_td = TensorDict(source={"observation": obs}, # [bs,nob]                                        
                                  batch_size= self.batch_size, 
                                  device=self.device,)
            
            agent_tds.append(agent_td)

        agent_tds = torch.stack(agent_tds, dim=1)

        agent_tds = agent_tds.to_tensordict() 
        
        done = torch.zeros(tensordict.shape, dtype=torch.bool)
        
        resetTD = TensorDict({"agents": agent_tds,
                              "done": done},
                             batch_size = tensordict.shape,
                             device=self.device)
        
        if rendering:
            state = {"obs":resetTD["agents"]["observation"][0]}
            self.render(state)

        return resetTD

    def _make_spec(self):
        agent = [{} for _ in range(self.n_agents)]  # one dict per agent ([{}] * n would alias a single dict)

        action_specs = []
        observation_specs = []
        reward_specs = []
        
        for i in range(self.n_agents):
            agent[i]["action_spec"] = BoundedTensorSpec(low = -DEFAULT_X/100,
                                                        high = DEFAULT_X/100,
                                                        shape = (3,),
                                                        dtype = torch.float32)

            agent[i]["reward_spec"] = UnboundedContinuousTensorSpec(shape = (3,))

            agent[i]["observation_spec"] = BoundedTensorSpec(low = -DEFAULT_X,
                                                             high = DEFAULT_X,
                                                             shape = (6,),
                                                             dtype = torch.float32)
                    
            
            action_specs.append(agent[i]["action_spec"])
            reward_specs.append(agent[i]["reward_spec"])
            observation_specs.append(agent[i]["observation_spec"])

        self.unbatched_action_spec = CompositeSpec(
            {"agents":CompositeSpec(
                {"action": torch.stack(action_specs,dim = 0)}, shape = (self.n_agents,)
                )
             })
        self.unbatched_reward_spec = CompositeSpec(
            {"agents":CompositeSpec(
                {"reward": torch.stack(reward_specs,dim = 0)}, shape = (self.n_agents,)
                )
             })
        self.unbatched_observation_spec = CompositeSpec(
            {"agents":CompositeSpec(
                {"observation": torch.stack(observation_specs,dim = 0)}, shape = (self.n_agents,)
                )
             })
        
        self.unbatched_done_spec = DiscreteTensorSpec(n = 2,
                                                      shape = torch.Size((1,)),
                                                      dtype = torch.bool)
        
        self.action_spec = self.unbatched_action_spec.expand(
            *self.batch_size, *self.unbatched_action_spec.shape
        )
        self.observation_spec = self.unbatched_observation_spec.expand(
            *self.batch_size, *self.unbatched_observation_spec.shape
        )
        self.reward_spec = self.unbatched_reward_spec.expand(
            *self.batch_size, *self.unbatched_reward_spec.shape
        )
        self.done_spec = self.unbatched_done_spec.expand(
            *self.batch_size, *self.unbatched_done_spec.shape
        )

    def render(self,state):
        self.screen_w = 600
        self.screen_h = 600

        if self.screen is None:
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((self.screen_w, self.screen_h))
            pygame.display.set_caption('View of the first element of the batch')
        if self.clock is None:
            self.clock = pygame.time.Clock()

        self.surf = pygame.Surface((self.screen_w, self.screen_h))
        self.surf.fill((255, 255, 255))

        bound = DEFAULT_X
        scale = self.screen_w / (bound * 2)
        offset = self.screen_w // 2
        agent_width = int(0.07 * scale)

        for i in range(self.n_agents):
            x = state["obs"][i][0].item()
            y = state["obs"][i][1].item()
            z = state["obs"][i][2].item()
            if z >0.9:z=0.9
            if z <-0.9:z=-0.9
            D = int(agent_width*(1-z) / 2)            
            col = int(i*255/self.n_agents)
            
            # drawing agents    
            gfxdraw.aacircle(self.surf, int(x* scale )+offset, int(y* scale )+offset, D, (255, col, 0))
            gfxdraw.filled_circle(self.surf, int(x* scale)+offset, int(y* scale )+offset, D, (255, col, 0))

        # drawing the target (just the last one, since the 3 agents share the same target)
        Tx = state["obs"][i][3].item()
        Ty = state["obs"][i][4].item()
        Tz = state["obs"][i][5].item()
        if Tz >0.9:Tz=0.9
        if Tz <-0.9:Tz=-0.9
        D = int(agent_width*(1-Tz) / 2)
        
        gfxdraw.aacircle(self.surf, int(Tx* scale)+offset, int(Ty* scale )+offset, D, (0, 0, 0))
       
        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        
        self.clock.tick(render_fps)
        pygame.display.flip()
        
    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()
        
             
if __name__ == "__main__":
    env = SimpleNagentsEnv(batch_size = [5])
    
    print("\n*action_spec:", env.full_action_spec)
    print("\n*reward_spec:", env.full_reward_spec)
    print("\n*done_spec:", env.full_done_spec)
    print("\n*observation_spec:", env.observation_spec)
    
    print("\n-action_keys:", env.action_keys)
    print("\n-reward_keys:", env.reward_keys)
    print("\n-done_keys:", env.done_keys)
    check_env_specs(env)

    print("\n* random action: \n", env.action_spec.rand())

    n_rollout_steps = 3
    rollout = env.rollout(n_rollout_steps)
    print("rollout of three steps:", rollout)
    print("Shape of the rollout TensorDict:", rollout.batch_size)
    print("ROOT:", rollout.exclude("next"))
    print("NEXT:",rollout.get("next"))
    env.close()

And the code for testing with a simple policy:
TestMASP.py:

# ref https://pytorch.org/rl/reference/envs.html#multi-agent-environments
# Torch
import torch

# Tensordict modules
from tensordict.nn import TensorDictModule

# Env
from torchrl.envs import TransformedEnv 
from torchrl.envs.utils import check_env_specs

# Multi-agent network
from torchrl.modules import MultiAgentMLP

# Loss
from torchrl.objectives import ClipPPOLoss

# Utils
from matplotlib import pyplot as plt
from collections import defaultdict

torch.manual_seed(0)

#####################################
# HYPERPARAMETERS
#####################################

# Devices
device = "cpu" 

# Training
lr = 3e-4  
max_grad_norm = 1.0 

#####################################
# ENVIRONMENT
#####################################

import Simple3D_NAgentsEnv
env = Simple3D_NAgentsEnv.SimpleNagentsEnv(batch_size = [16])

max_steps = 100  # Episode steps before done
          
# Test
check_env_specs(env)

#####################################
# POLICY
#####################################

net = torch.nn.Sequential(
    MultiAgentMLP(
        n_agent_inputs = env.observation_spec["agents", "observation"].shape[-1],
        n_agent_outputs = env.action_spec.shape[-1],  
        n_agents = env.n_agents,
        centralised = False,
        share_params = False,
        device = device,
        depth = 2,
        num_cells = 64,
        activation_class = torch.nn.Tanh,
    ),
)

# wrap the neural network in a TensorDictModule
policy = TensorDictModule(net,
                          in_keys=[("agents", "observation")],
                          out_keys=[("agents", "action")])

print("\nRunning policy:", policy(env.reset()))

optim = torch.optim.Adam(policy.parameters(), lr)

#####################################
# TRAINING LOOP
#####################################

frames = 16_000
iterations = frames // env.batch_size[0]

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, frames)
logs = defaultdict(list)

for i in range(iterations):
    init_td = env.reset()
    rollout = env.rollout(max_steps, policy, tensordict=init_td, auto_reset=False)
    traj_return = rollout["next", "agents","reward"].mean()
    (-traj_return).backward()
    gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    progress = 100 * i / iterations
    print(f"{progress: 3.1f}%, reward: {traj_return: 4.4f}, "
          f"last reward: {rollout[..., -1]['next', 'agents', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}")
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "agents", "reward"].mean().item())
    logs["grad_norm"].append(gn)
    scheduler.step()

env.close()
def plot():
    import matplotlib    
    from matplotlib import pyplot as plt
    
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 3, 1)
    plt.plot(logs["return"])
    plt.title("returns")
    plt.xlabel("iteration")

    plt.subplot(1, 3, 2)
    plt.plot(logs["last_reward"])
    plt.title("last reward")
    plt.xlabel("iteration")

    plt.subplot(1, 3, 3)
    plt.plot(logs["grad_norm"])
    plt.title("grad_norm")
    plt.xlabel("iteration")

    plt.show()
plot()

Hello @Jean-Marc_Fauche
Sorry that it took me so long to chime in here.
Can you summarize in a few lines what you’re trying to do and what problem you’re facing?
I’m not sure I understand what is blocking you atm