RewardSum in custom multi-agent env duplicating batch dimension

Hi,
I’ve defined a custom multi-agent environment. It passes check_env_specs, and I can run the following without issue on the base environment:

    td = env.reset()
    td = policy_modules['player_0'](td)
    td = policy_modules['player_1'](td)
    env.step(td)

However, when I chain in a RewardSum transform:

    env = TransformedEnv(
        base_env,
        Compose(
            long_to_float_transform,
            RewardSum(
                in_keys=base_env.reward_keys,
                reset_keys=["_reset"] * len(base_env.group_map.keys()),
            )
        )
    )

as in the Competitive Multi-Agent DDPG tutorial on multi-agent environments, env.step(...) fails with the following error (32 is the number of parallel environments):

E               RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([32, 1]) and value.shape=torch.Size([32, 32, 1]).

For some reason, the leading batch dimension is being duplicated somewhere along the way. Any suggestions as to why this is happening? The full traceback is below:

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
TorchAgent/CPMLTorchTraining.py:165: in main
    env.step(td)
.conda/lib/python3.11/site-packages/torchrl/envs/common.py:1461: in step
    next_tensordict = self._step(tensordict)
.conda/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py:789: in _step
    next_tensordict = self.transform._step(tensordict, next_tensordict)
.conda/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py:1074: in _step
    next_tensordict = t._step(tensordict, next_tensordict)
.conda/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py:5333: in _step
    next_tensordict.set(out_key, prev_reward + reward)
.conda/lib/python3.11/site-packages/tensordict/base.py:2651: in set
    return self._set_tuple(
.conda/lib/python3.11/site-packages/tensordict/_td.py:1911: in _set_tuple
    td._set_tuple(
.conda/lib/python3.11/site-packages/tensordict/_td.py:1896: in _set_tuple
    return self._set_str(
.conda/lib/python3.11/site-packages/tensordict/_td.py:1855: in _set_str
    value = self._validate_value(value, check_shape=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = TensorDict(
    fields={
        available_actions: Tensor(shape=torch.Size([32, 1, 64, 52]), device=cuda:0, dtype=tor...uda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([32, 1]),
    device=None,
    is_shared=False)
value = tensor([[[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0....     [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]]], device='cuda:0')

    def _validate_value(
        self,
        value: CompatibleType | dict[str, CompatibleType],
        *,
        check_shape: bool = True,
    ) -> CompatibleType | dict[str, CompatibleType]:
        cls = type(value)
        is_tc = None
        if issubclass(cls, dict):
            value = self._convert_to_tensordict(value)
            is_tc = True
        elif not issubclass(cls, _ACCEPTED_CLASSES):
            try:
                value = self._convert_to_tensor(value)
            except ValueError as err:
                raise ValueError(
                    f"TensorDict conversion only supports tensorclasses, tensordicts,"
                    f" numeric scalars and tensors. Got {type(value)}"
                ) from err
        batch_size = self.batch_size
        check_shape = check_shape and self.batch_size
        if (
            check_shape
            and batch_size
            and _shape(value)[: self.batch_dims] != batch_size
        ):
            # if TensorDict, let's try to map it to the desired shape
            if is_tc is None:
                is_tc = _is_tensor_collection(cls)
            if is_tc:
                # we must clone the value before not to corrupt the data passed to set()
                value = value.clone(recurse=False)
                value.batch_size = self.batch_size
            else:
>               raise RuntimeError(
                    f"batch dimension mismatch, got self.batch_size"
                    f"={self.batch_size} and value.shape={_shape(value)}."
                )
E               RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([32, 1]) and value.shape=torch.Size([32, 32, 1]).

.conda/lib/python3.11/site-packages/tensordict/base.py:5805: RuntimeError
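One thing I noticed while digging: plain tensor broadcasting can produce exactly this shape. The shapes below are only a guess from the traceback (the (32, 1, 1) for the tracked reward is an assumption on my part), but:

```python
import torch

# Guessed shapes: the running episode-reward tracker carries the env batch
# size (32, 1) plus a trailing reward dim -> (32, 1, 1), while the incoming
# reward arrives as (32, 1).
prev_reward = torch.zeros(32, 1, 1)
reward = torch.zeros(32, 1)

# Broadcasting right-aligns the dims: (32, 1, 1) + (32, 1) -> (32, 32, 1),
# which is exactly the mismatched shape reported in the error.
summed = prev_reward + reward
print(summed.shape)  # torch.Size([32, 32, 1])
```

So a reward entry missing (or carrying an extra) trailing dimension relative to what RewardSum tracks would reproduce the duplicated leading dimension via broadcasting.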

Hey!

Thanks for posting this!

Could you confirm that you are using the latest torchrl version?

In order for me to help, I would kindly ask you to strip your environment down to the bare minimum setup where the bug still occurs and send over a script that reproduces it.

Best

Matteo