Issues with PPO Tutorial and Custom Dictionary Observation Space

I am learning TorchRL and am trying to plug my own environment into the PPO tutorial. I also followed TorchRL's own Gymnasium conversion examples.

My new environment is defined as follows:

import gymnasium as gym
from gymnasium import spaces

class MyEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Dict(
            obs0=spaces.Box(-1, 1, (2,)),
            obs1=spaces.Box(-1, 1, (3,))
        )
        self.action_space = spaces.Box(-1, 1, (1,))

    def step(self, action):
        return self.observation_space.sample(), 1, False, False, {}

    def reset(self, **kwargs):
        return self.observation_space.sample(), {}

I then followed the PPO tutorial and changed this line from

base_env = GymEnv("InvertedDoublePendulum-v4", device=device)

to

base_env = GymWrapper(MyEnv())

Everything runs fine until this line:

print("Running policy:", policy_module(env.reset()))

where I get this error:
KeyError: 'Some tensors that are necessary for the module call may not have been found in the input tensordict: the following inputs are None: {'observation'}.'

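This makes sense: the policy module reads the "observation" key, but the tensordict returned by env.reset() has no such key. GymWrapper exposes each entry of the Dict space under its own key ("obs0" and "obs1"). I then nested both entries under an "observation" key: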

self.observation_space = spaces.Dict(
    observation=spaces.Dict(
        obs0=spaces.Box(-1, 1, (2,)),
        obs1=spaces.Box(-1, 1, (3,)),
    )
)

but that raises this error:

...
File "/home/myuser/.../site-packages/torch/nn/modules/linear.py", line 259, in initialize_parameters
    self.in_features = input.shape[-1]
IndexError: tuple index out of range

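With the nested space, "observation" is itself a TensorDict holding "obs0" and "obs1", so the network receives a TensorDict instead of a tensor. Every other tutorial that uses a Gym environment feeds the network a plain tensor. The lazy linear layer reads its input size from input.shape[-1], which works for a tensor, but the batch size of this TensorDict is torch.Size([]), so there is no last dimension to read.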
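To see why a flat tensor fixes this, here is a minimal pure-PyTorch sketch. It assumes the tutorial's policy uses torch.nn.LazyLinear layers (the traceback above comes from LazyLinear's initialize_parameters). The tensors obs0 and obs1 are hypothetical stand-ins for the two Box entries, not TorchRL outputs.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the two Dict entries, sampled like spaces.Box(-1, 1, shape).
obs0 = torch.rand(2) * 2 - 1  # 2 features
obs1 = torch.rand(3) * 2 - 1  # 3 features

# An empty shape, like the batch size of the nested "observation" TensorDict,
# has no last dimension: this is the same IndexError as in the traceback.
caught = None
try:
    torch.Size([])[-1]
except IndexError as err:
    caught = err
    print("IndexError:", err)

# Concatenating along the last dim gives one flat vector of 2 + 3 = 5 features.
observation = torch.cat([obs0, obs1], dim=-1)

layer = nn.LazyLinear(8)  # in_features is inferred on the first forward call
out = layer(observation)

print(observation.shape)  # torch.Size([5])
print(layer.in_features)  # 5
print(out.shape)  # torch.Size([8])

# Batched observations work the same way: only the last dim is treated as features.
batch = torch.cat([torch.rand(4, 2), torch.rand(4, 3)], dim=-1)
print(layer(batch).shape)  # torch.Size([4, 8])
```

The only thing the layer needs is a tensor whose last dimension is the feature size. That is why concatenating the entries into one "observation" tensor is enough.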

I solved this with the CatTensors transform.

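I reverted to the original flat observation space. Then, keeping the structure of the PPO tutorial, I added a CatTensors transform at the start of the Compose so that "obs0" and "obs1" are concatenated into a single "observation" entry before normalization: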

env = TransformedEnv(
    base_env,
    Compose(
        # concatenate "obs0" and "obs1" into a single "observation" tensor
        CatTensors(in_keys=["obs0", "obs1"], out_key="observation"),
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)