MultiDiscrete Observation Causes Shape Mismatch

torchrl version: 0.5.0
python: 3.11.10

Error:

RuntimeError: All input tensors (value, reward and done states) must share a unique shape.

To Replicate:
Follow the PPO tutorial and make the following changes:

  1. Replace the GymEnv with this:

class DummyEnv(gym.Env):
    def __init__(self, env_config={}, render_mode=None):
        self.action_space = Discrete(3)
        self.observation_space = MultiDiscrete([10, 10])

    def reset(self):
        return self.observation_space.sample(), {}

    def step(self, actions):
        return self.observation_space.sample(), 0, False, False, {}

  2. Change the environment transform to:

env = TransformedEnv(
    base_env,
    Compose(
        # Necessary since we have a discrete observation: the observation
        # space returns ints, but torch needs floats
        DTypeCastTransform(dtype_in=torch.int64, dtype_out=torch.float32),
        StepCounter(),
    ),
)

My assumption is that TorchRL interprets the extra observation dimension as a batch size. Is there a way to flatten the MultiDiscrete space?

Hello!
This code works fine on my end (with and without ParallelEnv) using torchrl nightly. I think it should work with v0.6 too!

import gym
import torch
from gym.spaces import Discrete, MultiDiscrete
from torchrl.envs import GymWrapper, TransformedEnv, StepCounter, Compose, DTypeCastTransform, ParallelEnv


class DummyEnv(gym.Env):
    def __init__(self, env_config={}, render_mode=None):
        self.action_space = Discrete(3)
        self.observation_space = MultiDiscrete([10, 10])

    def reset(self):
        return self.observation_space.sample(), {}

    def step(self, actions):
        return self.observation_space.sample(), 0, False, False, {}

if __name__ == "__main__":
    base_env = ParallelEnv(2, lambda: GymWrapper(DummyEnv()))

    env = TransformedEnv(
        base_env,
        Compose(
            # Necessary since we have discrete observation,
            #  observation space returns ints, but torch needs them to be floats
            DTypeCastTransform(in_keys=["observation"], out_keys=["observation"], dtype_in=torch.int64, dtype_out=torch.float32),
            StepCounter()
        )
    )

    env.check_env_specs()