Batch size in Rollout

When calling env.rollout, the environment internally calls the policy module once per step with a single sample.
That sample has the shape of the observation, without any batch_size dimension.
To keep things consistent, is there a way to add a batch_size dimension of 1 there, instead of unsqueezing inside the model itself?

Much appreciated!

IIUC, you want an env with a leading dimension of 1 on every tensor, is that right?
You can use BatchSizeTransform for this:
https://pytorch.org/rl/stable/reference/generated/torchrl.envs.transforms.BatchSizeTransform.html?highlight=batchsizetransform

This should do the trick:

from torchrl.envs.transforms import BatchSizeTransform
env = env.append_transform(BatchSizeTransform(reshape_fn=lambda x: x.unsqueeze(0)))