Hi,
I’ve defined a custom multi-agent environment. It passes check_env_specs
and I can run the following without issue on the base environment:
td = env.reset()
td = policy_modules['player_0'](td)
td = policy_modules['player_1'](td)
env.step(td)
However, when I chain in a RewardSum transform:
env = TransformedEnv(
base_env,
Compose(
long_to_float_transform,
RewardSum(
in_keys=base_env.reward_keys,
reset_keys=["_reset"] * len(base_env.group_map.keys()),
)
)
)
as in the Multi-agent environments tutorial (Competitive Multi Agent DDPG), the env.step(...) call
fails with the following error (32 is the number of environments).
E RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([32, 1]) and value.shape=torch.Size([32, 32, 1]).
For some reason, the leading batch dimension is getting duplicated somewhere along the way. Any suggestions as to why this is happening? The entire output is below:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
TorchAgent/CPMLTorchTraining.py:165: in main
env.step(td)
.conda/lib/python3.11/site-packages/torchrl/envs/common.py:1461: in step
next_tensordict = self._step(tensordict)
.conda/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py:789: in _step
next_tensordict = self.transform._step(tensordict, next_tensordict)
.conda/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py:1074: in _step
next_tensordict = t._step(tensordict, next_tensordict)
.conda/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py:5333: in _step
next_tensordict.set(out_key, prev_reward + reward)
.conda/lib/python3.11/site-packages/tensordict/base.py:2651: in set
return self._set_tuple(
.conda/lib/python3.11/site-packages/tensordict/_td.py:1911: in _set_tuple
td._set_tuple(
.conda/lib/python3.11/site-packages/tensordict/_td.py:1896: in _set_tuple
return self._set_str(
.conda/lib/python3.11/site-packages/tensordict/_td.py:1855: in _set_str
value = self._validate_value(value, check_shape=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = TensorDict(
fields={
available_actions: Tensor(shape=torch.Size([32, 1, 64, 52]), device=cuda:0, dtype=tor...uda:0, dtype=torch.float32, is_shared=True)},
batch_size=torch.Size([32, 1]),
device=None,
is_shared=False)
value = tensor([[[0.],
[0.],
[0.],
...,
[0.],
[0.],
[0.]],
[[0.... [[0.],
[0.],
[0.],
...,
[0.],
[0.],
[0.]]], device='cuda:0')
def _validate_value(
self,
value: CompatibleType | dict[str, CompatibleType],
*,
check_shape: bool = True,
) -> CompatibleType | dict[str, CompatibleType]:
cls = type(value)
is_tc = None
if issubclass(cls, dict):
value = self._convert_to_tensordict(value)
is_tc = True
elif not issubclass(cls, _ACCEPTED_CLASSES):
try:
value = self._convert_to_tensor(value)
except ValueError as err:
raise ValueError(
f"TensorDict conversion only supports tensorclasses, tensordicts,"
f" numeric scalars and tensors. Got {type(value)}"
) from err
batch_size = self.batch_size
check_shape = check_shape and self.batch_size
if (
check_shape
and batch_size
and _shape(value)[: self.batch_dims] != batch_size
):
# if TensorDict, let's try to map it to the desired shape
if is_tc is None:
is_tc = _is_tensor_collection(cls)
if is_tc:
# we must clone the value before not to corrupt the data passed to set()
value = value.clone(recurse=False)
value.batch_size = self.batch_size
else:
> raise RuntimeError(
f"batch dimension mismatch, got self.batch_size"
f"={self.batch_size} and value.shape={_shape(value)}."
)
E RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([32, 1]) and value.shape=torch.Size([32, 32, 1]).
.conda/lib/python3.11/site-packages/tensordict/base.py:5805: RuntimeError