MultiDiscrete Observation Causes Shape Mismatch

torchrl version: 0.5.0
python: 3.11.10

Error:

RuntimeError: All input tensors (value, reward and done states) must share a unique shape).

To Replicate:
Follow the PPO tutorial and make the following changes:

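  1. Replace the `GymEnv` with this custom environment:

```python
import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete


class DummyEnv(gym.Env):
    def __init__(self, env_config=None, render_mode=None):
        self.render_mode = render_mode
        self.action_space = Discrete(3)
        # Two integer features, each taking values in 0..9
        self.observation_space = MultiDiscrete([10, 10])

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)  # seeds self.np_random
        return self.observation_space.sample(), {}

    def step(self, action):
        # (observation, reward, terminated, truncated, info)
        return self.observation_space.sample(), 0, False, False, {}
```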

  2. Change the environment transform to:

```python
import torch
from torchrl.envs import Compose, DTypeCastTransform, StepCounter, TransformedEnv

env = TransformedEnv(
    base_env,
    Compose(
        # MultiDiscrete observations are int64, but the networks need float inputs
        DTypeCastTransform(dtype_in=torch.int64, dtype_out=torch.float32),
        StepCounter(),
    ),
)
```

My guess is that TorchRL treats the extra observation dimension as a batch dimension. Is there a way to flatten the MultiDiscrete space?
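One common way to flatten a MultiDiscrete space is to one-hot encode each component and concatenate the blocks. As far as I know, gymnasium's `FlattenObservation` wrapper applies this same encoding through `gymnasium.spaces.flatten`. The sketch below reproduces it in plain NumPy. `flatten_multidiscrete` is a hypothetical helper written for illustration; it is not a gymnasium or TorchRL API. I have not verified that this approach resolves the shape error above. For `MultiDiscrete([10, 10])`, an observation such as `[3, 7]` becomes a single float32 vector of length 20.

```python
import numpy as np


def flatten_multidiscrete(obs, nvec):
    """Hypothetical helper: one-hot encode each MultiDiscrete component and
    concatenate the results into one 1-D float32 vector of length sum(nvec)."""
    obs = np.asarray(obs, dtype=np.int64)
    nvec = np.asarray(nvec, dtype=np.int64)
    if obs.shape != nvec.shape:
        raise ValueError("obs must have one entry per MultiDiscrete component")
    if np.any(obs < 0) or np.any(obs >= nvec):
        raise ValueError("observation component out of range for nvec")
    # Start index of each component's one-hot block, e.g. [0, 10] for nvec [10, 10]
    offsets = np.concatenate(([0], np.cumsum(nvec)[:-1]))
    flat = np.zeros(int(nvec.sum()), dtype=np.float32)
    flat[offsets + obs] = 1.0
    return flat


flat = flatten_multidiscrete([3, 7], [10, 10])
print(flat.shape)             # (20,)
print(np.flatnonzero(flat))   # [ 3 17]
```

Using this inside `DummyEnv` would be an untested assumption. It would mean declaring `observation_space = Box(0.0, 1.0, shape=(20,), dtype=np.float32)` and returning the flattened vector from `reset` and `step`. That would also make the int64-to-float32 cast unnecessary. Another option is to wrap the environment as `FlattenObservation(DummyEnv())` before passing it to TorchRL. That should give the same layout, although the flattened observations may keep an integer dtype.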