Do TorchRL environments have a way to handle policies that output trajectories?

I’m working with policies that, given observations, output a fixed-horizon action trajectory. Once that horizon is consumed, we rinse and repeat. My current setup absorbs that logic into the environment with a loop that executes multiple simulation steps. But this is not ideal: I end up having to aggregate the rewards somehow, and the multiple simulation steps get reported as a single step.

What I’d like is for EnvBase.rollout to handle this logic by querying the policy once, consuming the resulting action trajectory over multiple environment steps (keeping track of the tensordict for each individual step), and then querying the policy again once the action trajectory is fully consumed.
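
Concretely, the behaviour I’m after looks something like this (a rough sketch against GymEnv; the 5-step “trajectory policy” is faked with random actions and all the names here are illustrative, not an existing API):

import torch
from torchrl.envs import GymEnv
from torchrl.envs.utils import step_mdp

env = GymEnv("CartPole-v1")
td = env.reset()
steps = []
done = False
while len(steps) < 100 and not done:
    # one policy query yields a whole horizon of actions, shape [horizon, action_dim]
    actions = torch.stack([env.rand_action(td.clone())["action"] for _ in range(5)])
    for action in actions:
        td["action"] = action          # consume the trajectory one env step at a time
        td = env.step(td)
        steps.append(td.clone())
        done = bool(td["next", "done"].any())
        if done:
            break
        td = step_mdp(td)              # promote "next" to the root for the next step
rollout = torch.stack(steps, 0)        # what I'd like rollout() to hand back directly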

I’m thinking of overriding EnvBase.rollout to handle this, but I’m wondering if I’ve missed a better way to do it. The closest thing I could find was the FrameSkipTransform, but it isn’t right for two reasons:

  1. It expects a single action and repeats that same action multiple times (rather than consuming a trajectory of distinct actions).
  2. It still aggregates the result of each step, so it doesn’t report rewards correctly and doesn’t break out on done conditions.
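
For reference, this is roughly how I understand FrameSkipTransform is used; as far as I can tell it repeats one action and aggregates the rewards over the skipped frames, which is not the trajectory semantics I need:

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import FrameSkipTransform

# FrameSkipTransform repeats the *same* action frame_skip times under the hood
env = TransformedEnv(GymEnv("CartPole-v1"), FrameSkipTransform(frame_skip=4))
td = env.reset()
td = env.rand_step(td)  # one visible step = 4 identical low-level steps, reward aggregated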

We don’t have this yet, but it’s been asked for already on GitHub and Discord!
Since it’s becoming a pattern, we should probably make this a thing in the next release.
The main “issue” I have with this is that env.rollout is built around stacking results one step at a time, and the method as it is now doesn’t handle batches of steps coming all at once.

Would it work if the policy outputs multiple actions, and the env has a transform that consumes those actions one by one (by querying them from a buffer) and passes them to the env in rollout? Or do you need something more “batched” on the env side than that?

Thanks for picking this up so quickly. I’m just learning about transforms now, but it looks like the way to write a transform is to override the _step method. Wouldn’t that run into the same issue of expecting a single “step” to return a single reward? And even if you make it return a list of rewards, you’d have a mismatch between the tensordict structure of the transform and that of the underlying environment.

If my interpretation is right, I’d think that having an inner loop in rollout would be more fitting.

Thoughts?

Again, I’m just learning TorchRL, so take my “feedback” with a grain of salt.

Yeah, I can see how difficult it will be to make this fit in, since we’ll also need the policy to be put on hold while waiting for the env to consume the actions passed to it…

This seems to be doing the trick:

import torch
from tensordict import assert_allclose_td

from torchrl.envs import GymEnv, Transform, TransformedEnv

base_env = GymEnv("CartPole-v1")
base_env.set_seed(0)

# reference rollout to compare against
rollout = base_env.rollout(100)


class BatchedActionTransform(Transform):
    def _inv_call(self, tensordict):
        # on the way in: unbind the batch of steps along the last dim, run all
        # but the last one directly on the parent env, and stash their results
        parent = self.parent
        next_tds = []
        tensordict_unbind = tensordict.unbind(-1)
        for td in tensordict_unbind[:-1]:
            next_tds.append(parent._step(td))
        self.next_tds = next_tds
        # the last step goes through the regular step machinery
        return tensordict_unbind[-1]

    def _step(self, tensordict, next_tensordict):
        # on the way out: stack the stashed results with the last one so the
        # output looks like a batch of steps again
        return torch.stack(self.next_tds + [next_tensordict], -1)


env = TransformedEnv(base_env, BatchedActionTransform())
env.set_seed(0)
# let the env accept inputs whose batch size differs from its own
env._batch_locked = False
env.base_env._batch_locked = False
assert not env.batch_locked  # sanity check: the patch took effect

one_big_step = env.step(rollout)

assert_allclose_td(rollout, one_big_step)

I need to patch the env (via its batch_locked flag) to let it know that the batch size of the input tensordict doesn’t need to match the env’s batch size.

Got it. And what are your thoughts on this (adding an inner loop to the rollout):

# requires (at module level): from torchrl.envs.utils import step_mdp, _terminated_or_truncated
def _rollout_stop_early(
    self,
    *,
    tensordict,
    auto_cast_to_device,
    max_steps,
    policy,
    policy_device,
    env_device,
    callback,
):
    """Override adds handling of multi-step policies."""
    tensordicts = []
    step_ix = 0
    do_break = False
    while not do_break:
        if auto_cast_to_device:
            if policy_device is not None:
                tensordict = tensordict.to(policy_device, non_blocking=True)
            else:
                tensordict.clear_device_()
        tensordict = policy(tensordict)
        if auto_cast_to_device:
            if env_device is not None:
                tensordict = tensordict.to(env_device, non_blocking=True)
            else:
                tensordict.clear_device_()

        actions = tensordict["action"].clone()
        if actions.ndim == 1:
            actions = actions.unsqueeze(0)
        elif actions.ndim > 2:
            raise RuntimeError("Expected actions to be (timesteps, action_dim) or (action_dim,) tensor")

        for action in actions:
            tensordict["action"] = action
            tensordict = self.step(tensordict)
            tensordicts.append(tensordict.clone(False))

            if step_ix == max_steps - 1:
            # we don't mark this as truncated since one could potentially continue the run
                do_break = True
                break
            tensordict = step_mdp(
                tensordict,
                keep_other=True,
                exclude_action=False,
                exclude_reward=True,
                reward_keys=self.reward_keys,
                action_keys=self.action_keys,
                done_keys=self.done_keys,
            )
            # done and truncated are in done_keys
            # We read if any key is done.
            any_done = _terminated_or_truncated(
                tensordict,
                full_done_spec=self.output_spec["full_done_spec"],
                key=None,
            )
            if any_done:
                # stop rolling out entirely, not just the current trajectory
                do_break = True
                break

            if callback is not None:
                callback(self, tensordict)

            step_ix += 1

    return tensordicts
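
For completeness, this is how I’d wire it in (a sketch: it assumes rollout with break_when_any_done=True dispatches to _rollout_stop_early, and my_trajectory_policy is a placeholder for my actual multi-step policy):

from torchrl.envs import GymEnv

class TrajectoryPolicyEnv(GymEnv):
    # bind the override above; this is private machinery, so it may change across releases
    _rollout_stop_early = _rollout_stop_early

env = TrajectoryPolicyEnv("CartPole-v1")
# my_trajectory_policy stands in for a policy that writes a [horizon, action_dim] "action"
out = env.rollout(100, policy=my_trajectory_policy, break_when_any_done=True)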

That seems to be working yes!

That will work in some specific cases, but in others people will have inputs other than actions (e.g. stateless environments, or multi-agent setups with multiple action keys).

I like the idea of the transform for this (like the FrameSkipTransform) because it decouples the base env implementation from the ad-hoc features one may want.
For instance, if the logic of the transform isn’t quite right for someone, it’s more straightforward to clone the transform and modify it than to change the inner rollout method.
I think it’s a “mistake” we made with GymLikeEnv, where we implemented a built-in, fast frame-skip, but now people are asking for extra features from it and it’s hard to make everything fit within the same class. Using a transform from the beginning would have given us more degrees of freedom.

If that works in your case, you could inherit from the env you’re using and change the corresponding method, although I’d be mindful that it’s a private method and it may or may not keep the same signature and behaviour in future releases.

One thing this solution has that mine didn’t, though, is that it allows the input tensordict to keep the right shape. I’m not sure how to replicate that with a transform in its current state.

So overall I think that the options are:

  • Write a dedicated function for this like torchrl.envs.collect_batched_actions
  • Make a TransformedEnv (not Transform) subclass that does exactly this within its rollout. That way you’d just do
    env = BatchedActionsEnv(base_env)
    td = env.reset()
    td_with_batched_action = policy(td) # adds N actions to the td reset
    env.step(td_with_batched_action)  # gives back a tensordict of shape [N] containing all the actions and obs as if it had been called on each action independently
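
A minimal skeleton of what that second option could look like (just a sketch, none of this exists yet; spec handling is left out and only the happy path plus an early exit on done is shown):

import torch
from torchrl.envs import TransformedEnv
from torchrl.envs.utils import step_mdp

class BatchedActionsEnv(TransformedEnv):
    """Hypothetical wrapper: step() consumes a [N, ...] action trajectory one action at a time."""
    def step(self, tensordict):
        results = []
        td = tensordict.exclude("action")
        for action in tensordict["action"].unbind(0):
            td["action"] = action
            td = super().step(td)
            results.append(td.clone())
            if td["next", "done"].any():  # stop consuming the trajectory when done
                break
            td = step_mdp(td)
        # one entry per consumed action, as if step had been called on each independently
        return torch.stack(results, 0)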
    

I like the second option: it’s modular, reusable, and kinda obvious to use.

Excellent! Well, I’m going to think on this some more while I get up to speed on TorchRL, and I’ll ping back if I have thoughts. For the time being, thanks!