Defining a ProbabilisticActor with two normal distributions

Hi All,

I have just started using PyTorch and TorchRL, and I am trying to adapt the PPO tutorial to my own setting.

I want to train robots that have a mixture of wheels and legs, so I have a composite action spec with joints and wheels outputs.
However, I cannot manage to define a working probabilistic actor.

My action_spec:

        self.action_spec = CompositeSpec(
            joints=BoundedTensorSpec(
                low=-torch.pi,
                high=torch.pi,
                shape=(len(self.robot.joint_ids),),
                dtype=torch.float32,
                device=self.device,
            ),
            wheels=BoundedTensorSpec(
                low=-100,
                high=100,
                shape=(len(self.robot.wheel_ids),),
                dtype=torch.float32,
                device=self.device,
            ),
        )

And the definition of my actor:

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec["joints"].shape[-1] + 2 * env.action_spec["wheels"].shape[-1], device=device),
    NormalParamExtractor(),
)

# Define the policy module
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], 
    out_keys=["joints_loc", "joints_scale", "wheels_loc", "wheels_scale"]
)


actor = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["joints_loc", "joints_scale", "wheels_loc", "wheels_scale"],
    out_keys=["joints", "wheels"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "min": torch.cat([env.action_spec["joints"].space.low, env.action_spec["wheels"].space.low]),
        "max": torch.cat([env.action_spec["joints"].space.high, env.action_spec["wheels"].space.high]),
    },
)

When I run actor(env.reset()) I get this error:

KeyError: 'key "wheels_loc" not found in TensorDict with keys [\'done\', \'joints_loc\', \'joints_scale\', \'observation\', \'step_count\', \'terminated\']'

It seems that only the first two out_keys of my module are taken into account.
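
My guess is that this happens because NormalParamExtractor only returns two tensors (it splits the last layer's output into a loc and a scale), so only the first two out_keys can ever be filled. A quick standalone check (shapes made up for illustration):

import torch
from tensordict.nn import NormalParamExtractor

x = torch.randn(5, 8)                    # pretend this is the output of the last linear layer
loc, scale = NormalParamExtractor()(x)   # chunks the last dimension in two
print(loc.shape, scale.shape)            # torch.Size([5, 4]) torch.Size([5, 4])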

How would you implement a probabilistic actor with a composite action output? Or how would you define a probabilistic actor with two normal distributions?

Many thanks!

Use a CompositeDistribution:

from torch import distributions as d
from tensordict import TensorDict
import torch
from tensordict.nn import CompositeDistribution

params = TensorDict({
    "joint": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
    "wheel": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
}, [3])
dist = CompositeDistribution(params,
    distribution_map={"joint": d.Normal, "wheel": d.Normal})
sample = dist.sample((4,))
print(sample)
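
For reference, the sample is itself a TensorDict with one entry per distribution ("joint" and "wheel" here). A quick sanity check on the snippet above (the shapes follow from sampling with sample_shape=(4,) on top of a batch of 3 and an event size of 4):

print(sample["joint"].shape)  # torch.Size([4, 3, 4])
print(sample["wheel"].shape)  # torch.Size([4, 3, 4])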

LMK if you need further help!

Thank you for the quick answer.
So I integrated a CompositeDistribution following your suggestion:

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec["joints"].shape[-1] + 2 * env.action_spec["wheels"].shape[-1], device=device),
    NormalParamExtractor(),
)

params = TensorDict({
    "joints": {"loc": torch.rand(env.action_spec["joints"].shape[-1]), "scale": torch.rand(env.action_spec["joints"].shape[-1])},
    "wheels": {"loc": torch.rand(env.action_spec["wheels"].shape[-1]), "scale": torch.rand(env.action_spec["wheels"].shape[-1])}},
    batch_size=[],
)


# Define the policy module
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], 
    out_keys=[("joints","loc"), ("joints"),("scale"), ("wheels","loc"), ("wheels","scale")]
)


actor = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=[("joints", "loc"), ("joints", "scale"), ("wheels", "loc"), ("wheels", "scale")],
    out_keys=["joints", "wheels"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={
        "params": params,
        "distribution_map": {"joints": d.Normal, "wheels": d.Normal},
        "min": torch.cat([env.action_spec["joints"].space.low, env.action_spec["wheels"].space.low]),
        "max": torch.cat([env.action_spec["joints"].space.high, env.action_spec["wheels"].space.high]),
    },
)

Now I have this error:

AttributeError: 'BoundedTensorSpec' object has no attribute 'items'

I guess it comes from the action_spec and from how I write the out_keys of the policy_module and the in_keys of the ProbabilisticActor.

I am not sure how to write them correctly.

Here is the corrected code:

import torch
from torchrl.data import ReplayBuffer, LazyMemmapStorage
from torch import distributions as d
from tensordict.nn import NormalParamExtractor, TensorDictModule as Mod, TensorDictSequential as Seq, ProbabilisticTensorDictModule as Prob, CompositeDistribution
from tensordict import TensorDict
from torch import nn
num_cells = 32
device = "cpu"
num_joints = 4
num_wheels = 3

actor_net_base = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
)
actor_net_joints = nn.Sequential(
    nn.LazyLinear(2 * num_joints, device=device),
    NormalParamExtractor(),
)
actor_net_wheels = nn.Sequential(
    nn.LazyLinear(2 * num_wheels, device=device),
    NormalParamExtractor(),
)
actor_module = Seq(
    Mod(actor_net_base, in_keys=["obs"], out_keys=["embed"]),
    Mod(actor_net_joints, in_keys=["embed"], out_keys=[("params", "joints", "loc"), ("params", "joints", "scale")]),
    Mod(actor_net_joints, in_keys=["embed"], out_keys=[("params", "wheels", "loc"), ("params", "wheels", "scale")]),
)

# Define the policy module

actor = Seq(
    actor_module,
    Prob(
        in_keys=["params"],
        out_keys=["joints", "wheels"],
        distribution_class=CompositeDistribution,
        distribution_kwargs={
            "distribution_map": {"joints": d.Normal, "wheels": d.Normal},
        },
    ),
)
# With a tensordict as input
print(actor(TensorDict(obs=torch.randn(1, 3))))
# With a tensor as input
print(actor(obs=torch.randn(1, 3)))

You need to put the params in a container (e.g., "params" in our case), because otherwise you would override the "joints" and "wheels" entries during sampling.
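
For illustration, the prints above should produce roughly the following structure (with num_joints=4 and num_wheels=3 as above): the parameters stay nested under "params" while the samples are written at the root, so nothing gets overwritten.

# Rough sketch of the output tensordict:
# TensorDict(
#     fields={
#         embed:  Tensor(shape=torch.Size([1, 32]), ...),
#         joints: Tensor(shape=torch.Size([1, 4]), ...),
#         obs:    Tensor(shape=torch.Size([1, 3]), ...),
#         params: TensorDict(
#             fields={
#                 joints: TensorDict(fields={loc, scale}, ...),
#                 wheels: TensorDict(fields={loc, scale}, ...)},
#             ...),
#         wheels: Tensor(shape=torch.Size([1, 3]), ...)},
#     batch_size=torch.Size([]))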

Thank you for this solution and explanation!
It is a bit clearer now.
So I have implemented the ProbabilisticActor as follows:

actor = ProbabilisticActor(
    module=actor_module,
    spec=env.action_spec,
    in_keys=["params"],
    out_keys=["joints", "wheels"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={
        "distribution_map": {"joints": d.Normal, "wheels": d.Normal},
    },
    return_log_prob=True,
)

It works fine, but the "action" key is missing, which I believe is needed for PPO training. I tried adding "name_map": {"joints": "action", "wheels": "action"}, to the distribution_kwargs, but I get an unexpected keyword argument error from CompositeDistribution.

Do you have a suggestion to solve this issue?

Thank you for all the help!

Try this:

policy = TensorDictSequential(
    policy,
    TensorDictModule(lambda x, y: torch.cat([x, y], -1), in_keys=["joints", "wheels"], out_keys=["action"]),
)

It does not work, because the distribution is then no longer accessible.
I get this error:

AttributeError: 'TensorDictSequential' object has no attribute 'get_dist'

Where are you getting that error?

I get this error at this line: loss_vals = loss_module(subdata.to(device))

My code is the same as in the TorchRL PPO tutorial.

Where is get_dist called?
Can you share more info, like the full error stack and the lines of code causing it?

Yes, sure.
Here is the full error message:

Traceback (most recent call last):
  File "/home/leni/.local/lib/python3.10/site-packages/tensordict/nn/common.py", line 1242, in __getattr__
    return super().__getattr__(name)
  File "/home/leni/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'TensorDictSequential' object has no attribute 'get_dist'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/leni/git/pyARE/ppo.py", line 252, in <module>
    loss_vals = loss_module(subdata.to(device))
  File "/home/leni/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/leni/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/leni/.local/lib/python3.10/site-packages/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
  File "/home/leni/.local/lib/python3.10/site-packages/tensordict/nn/common.py", line 291, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/leni/.local/lib/python3.10/site-packages/torchrl/objectives/ppo.py", line 762, in forward
    log_weight, dist = self._log_weight(tensordict)
  File "/home/leni/.local/lib/python3.10/site-packages/torchrl/objectives/ppo.py", line 474, in _log_weight
    dist = self.actor_network.get_dist(tensordict)
  File "/home/leni/.local/lib/python3.10/site-packages/tensordict/nn/common.py", line 1249, in __getattr__
    raise err2 from err1
  File "/home/leni/.local/lib/python3.10/site-packages/tensordict/nn/common.py", line 1247, in __getattr__
    return getattr(super().__getattr__("module"), name)
  File "/home/leni/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'ModuleList' object has no attribute 'get_dist'

Let’s backtrack for a second.
If you need a single action vector, why build two separate dists? Why not just concatenate the two loc and scale vectors?
We usually use two separate distributions when the actor expects two different components in the action, but here my understanding is that you only need one, am I right?

Originally, I needed two action vectors because I have two types of actuators: joints and wheels.
They don’t have the same bounds. Also, the wheels are velocity-controlled while the joints are position-controlled.

That said, I could use one action vector with normalised values.

In TorchRL, the TanhNormal distribution accepts non-homogeneous bounds. You can pass vectors, e.g. distribution_kwargs={"low": torch.tensor([-2.5, -2.5, -1.0, -1.0]), "high": torch.tensor([2.5, 2.5, 1.0, 1.0])}, and it will work just fine!
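
For instance, a minimal standalone sketch (in recent TorchRL versions the keyword arguments are low/high; older versions use min/max):

import torch
from torchrl.modules import TanhNormal

loc = torch.zeros(4)
scale = torch.ones(4)
# First two dimensions bounded in [-2.5, 2.5], last two in [-1.0, 1.0]
dist = TanhNormal(
    loc,
    scale,
    low=torch.tensor([-2.5, -2.5, -1.0, -1.0]),
    high=torch.tensor([2.5, 2.5, 1.0, 1.0]),
)
samples = dist.sample((1000,))
print(samples.min(0).values, samples.max(0).values)  # each dimension stays within its own bounds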

OK, so I modified the code to concatenate my joints and wheels into a single loc and scale:

actor_module = TensorDictSequential(
    TensorDictModule(actor_net_base,in_keys=["observation"],out_keys=["embed"]),
    TensorDictModule(actor_net_joints,in_keys=["embed"],out_keys=[("params", "joints", "loc"), ("params", "joints", "scale")]),
    TensorDictModule(actor_net_wheels,in_keys=["embed"],out_keys=[("params", "wheels", "loc"), ("params", "wheels", "scale")]),
    TensorDictModule(lambda x, y, z, w: (torch.cat([x, y]), torch.cat([z,w])), 
                     in_keys=[("params", "joints", "loc"), ("params", "wheels", "loc"),("params", "joints", "scale"), ("params", "wheels", "scale")], 
                     out_keys=["loc","scale"]),
)

actor = ProbabilisticActor(
    module=actor_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    out_keys=["joints", "wheels"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "min": torch.cat([env.action_spec["joints"].space.low, env.action_spec["wheels"].space.low]),
        "max": torch.cat([env.action_spec["joints"].space.high, env.action_spec["wheels"].space.high]),
    },
    return_log_prob=True,
)

I don’t know if this is the best way to concatenate tensors.
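
(Side note: I also wondered whether I should concatenate along the last dimension explicitly, so that a leading batch dimension added by the collector is preserved; something like this:)

TensorDictModule(
    lambda jl, wl, js, ws: (torch.cat([jl, wl], dim=-1), torch.cat([js, ws], dim=-1)),
    in_keys=[("params", "joints", "loc"), ("params", "wheels", "loc"), ("params", "joints", "scale"), ("params", "wheels", "scale")],
    out_keys=["loc", "scale"],
)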

Then, when I call my collector, I get the following error:

Traceback (most recent call last):                         
  File "/home/leni/git/pyARE/ppo.py", line 245, in <module> 
    for i, tensordict_data in enumerate(collector):
  File "/home/leni/.local/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 952, in iterator
    tensordict_out = self.rollout()
  File "/home/leni/.local/lib/python3.10/site-packages/torchrl/_utils.py", line 469, in unpack_rref_and_invoke_functio
n                            
    return func(self, *args, **kwargs)
  File "/home/leni/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/leni/.local/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1069, in rollout
    env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
  File "/home/leni/.local/lib/python3.10/site-packages/torchrl/envs/common.py", line 2576, in step_and_maybe_reset
    tensordict = self.step(tensordict)
  File "/home/leni/.local/lib/python3.10/site-packages/torchrl/envs/common.py", line 1409, in step
    next_tensordict = self._step(tensordict)
  File "/home/leni/.local/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 738, in _step
    next_tensordict = self.base_env._step(tensordict_in)
  File "/home/leni/git/pyARE/rl_env.py", line 71, in _step
    wheel_actions = action["wheels"].tolist()
  File "/home/leni/.local/lib/python3.10/site-packages/tensordict/base.py", line 242, in __getitem__
    result = self._get_tuple(idx_unravel, NO_DEFAULT)
  File "/home/leni/.local/lib/python3.10/site-packages/tensordict/_td.py", line 1629, in _get_tuple
    first = self._get_str(key[0], default)
  File "/home/leni/.local/lib/python3.10/site-packages/tensordict/_td.py", line 1625, in _get_str
    return self._default_get(first_key, default)
  File "/home/leni/.local/lib/python3.10/site-packages/tensordict/base.py", line 2430, in _default_get
    raise KeyError(
KeyError: 'key "wheels" not found in TensorDict with keys [\'collector\', \'done\', \'embed\', \'joints\', \'loc\', \'
observation\', \'params\', \'sample_log_prob\', \'scale\', \'step_count\', \'terminated\']'

I don’t understand why I lose the “wheels” key.

The collector is defined as follows:

collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

I decided to stop trying to use a composite action spec and switched to a single action tensor for the actor. I then split the action tensor into two tensors before sending the commands to the simulator.
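
Roughly, the split looks like this (a minimal sketch; num_joints, num_wheels and the command calls stand in for my own code):

# Inside the env's _step: split the single sampled action back into joint and wheel commands
action = tensordict["action"]
joint_actions, wheel_actions = torch.split(action, [num_joints, num_wheels], dim=-1)
# joint_actions are sent as position targets, wheel_actions as velocity targets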

It seems to work now.

Thank you so much for all your help! I understand better how torch works now.

Yeah I think it’s the simplest option! Glad it works!