Using buffers in ParallelEnvs / MultiSyncCollectors

Hi, I’ve been moving a custom environment from dynamically sized to statically known observation sizes, and I ran into unexpected behavior in how the input and output TensorDicts of _step() are handled.

What is the expected convention for how next = _step(td) should read and write its observations? Am I doing anything wrong?

Based on the documentation, I expected the following (sketched in code right after this list):

  • “td” contains the observations from the last output in the trajectory
  • “next” (the TensorDict returned by _step) holds the new observations, which become that “last output” for the following step
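
In code, the loop I have in mind is roughly the following (a sketch using step_mdp from torchrl.envs.utils and the TestEnv from the repro below):

from torchrl.envs.utils import step_mdp

env = TestEnv()
td = env.reset()
for _ in range(3):
    td = env.rand_action(td)   # fill in a random action
    td = env.step(td)          # _step(td) writes the new observations under "next"
    td = step_mdp(td)          # "next" becomes the root of the next step's input
# so the observation _step sees at step t+1 should be the one it wrote at step t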

I do observe this when running with “use_buffers=False” (or with dynamically sized inputs, since that also disables the buffers).

But when running with “use_buffers=True”, every step reads the observation from the initial reset.

I created a minimal example that simply reads the last observation and increments it.

Testing ParallelEnv
========================================
No buffers (expected behavior)
Input:  tensor(1)
Output:  tensor(2)
Input:  tensor(2)
Output:  tensor(3)
Input:  tensor(3)
Output:  tensor(4)
========================================
Buffers (unexpected behavior)
Input:  tensor(1)
Output:  tensor(2)
Input:  tensor(1)
Output:  tensor(2)
Input:  tensor(1)
Output:  tensor(2)

Code to reproduce:

import torch
from typing import Optional
from torchrl.envs import EnvBase
from torchrl.data import Composite, TensorSpec, Unbounded, Binary, Bounded
from torchrl.envs import check_env_specs
from tensordict import TensorDict


class TestEnv(EnvBase):
    def __init__(self, seed: int = 0, device="cpu"):
        super().__init__(device=device)

        self.observation_spec = self._create_observation_spec()
        self.action_spec = self._create_action_spec()
        self.reward_spec = self._create_reward_spec()
        self.done_spec = Binary(shape=(1,), device=self.device, dtype=torch.bool)

    def _create_observation_spec(self) -> TensorSpec:
        obs = Unbounded(shape=(1,), device=self.device, dtype=torch.int64)
        comp = Composite(value=obs)
        comp = Composite(observation=comp)
        return comp

    def _create_action_spec(self) -> TensorSpec:
        out = Bounded(
            shape=(1,),
            device=self.device,
            dtype=torch.int32,
            low=torch.tensor(0, device=self.device),
            high=torch.tensor(4, device=self.device),
        )
        out = Composite(action=out)
        return out

    def _create_reward_spec(self) -> TensorSpec:
        return Unbounded(shape=(1,), device=self.device, dtype=torch.float32)

    def _get_observation(self, v) -> TensorDict:
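        # Build a fresh nested TensorDict {"observation": {"value": v + 1}} on every call.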
        obs = TensorDict(value=torch.tensor([0], device=self.device, dtype=torch.int64))
        obs["value"][0] = v + 1
        obs = TensorDict(observation=obs)
        return obs

    def _step(self, td: TensorDict) -> TensorDict:
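        # Print the incoming observation and the one written back, so it is easy to
        # see whether the input advances each step or stays stuck at the reset value.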
        print("Input: ", td["observation"]["value"][0])
        out = self._get_observation(td["observation"]["value"][0])
        print("Output: ", out["observation"]["value"][0])
        out.set("reward", torch.tensor([1], device=self.device, dtype=torch.float32))
        out.set("done", torch.tensor([False], device=self.device, dtype=torch.bool))
        return out

    def _reset(self, td: Optional[TensorDict] = None) -> TensorDict:
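        # Every reset restarts the counter: the first observation is always 1.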
        obs = self._get_observation(0)

        return obs

    def _set_seed(self, seed: Optional[int] = None):
        rng = torch.manual_seed(seed)
        self.rng = rng


def make_env():
    return TestEnv()


if __name__ == "__main__":
    from torchrl.envs import ParallelEnv

    print("Testing ParallelEnv")
    print("========================================")
    print("No buffers")
    workers = 1
    penv = ParallelEnv(
        workers,
        [make_env for _ in range(workers)],
        use_buffers=False,
    )
    r = penv.rollout(3)

    print("========================================")
    print("Buffers")

    workers = 1
    penv = ParallelEnv(
        workers,
        [make_env for _ in range(workers)],
        use_buffers=True,
    )
    r = penv.rollout(3)

Also, is there a way to write to the preallocated memory “in place”, so that temporary allocations for the new observations at each step (before they are copied into the buffers) can be avoided?
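
What I have in mind is something like the following variant of _step (a sketch only; I don’t know whether reusing a single output TensorDict across steps is actually supported):

class InPlaceTestEnv(TestEnv):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # preallocate the output once and mutate its tensors in place at every step
        self._out = TensorDict(
            observation=TensorDict(value=torch.zeros(1, dtype=torch.int64, device=self.device)),
            reward=torch.zeros(1, dtype=torch.float32, device=self.device),
            done=torch.zeros(1, dtype=torch.bool, device=self.device),
        )

    def _step(self, td: TensorDict) -> TensorDict:
        self._out["observation", "value"].copy_(td["observation", "value"] + 1)
        self._out["reward"].fill_(1.0)
        self._out["done"].fill_(False)
        return self._out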

I think it might be that ‘td’ is only guaranteed to have an updated action field during _step(td); is that the case?
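
If that is the case, the fix on my side would be to keep the counter as env state instead of reading it back from the input, roughly (a sketch; these methods would replace the ones in TestEnv above):

    def _reset(self, td: Optional[TensorDict] = None) -> TensorDict:
        self._count = 0
        return self._get_observation(self._count)

    def _step(self, td: TensorDict) -> TensorDict:
        # rely only on td["action"] (ignored here) and internal state
        self._count += 1
        out = self._get_observation(self._count)
        out.set("reward", torch.tensor([1.0], device=self.device))
        out.set("done", torch.tensor([False], device=self.device))
        return out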

However, even when I only use the action information and do not access any other fields of ‘td’, MultiSyncDataCollector fails on this example because buffers is an empty dictionary at
“workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()” (torchrl/collectors/collectors.py:2378), so something else is probably going on?
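
For reference, the collector side is invoked along these lines (a sketch with arbitrary frame counts; policy=None should fall back to a random policy over the action spec):

from torchrl.collectors import MultiSyncDataCollector

collector = MultiSyncDataCollector(
    [make_env],          # one env-creating callable per worker
    policy=None,
    frames_per_batch=4,
    total_frames=12,
)
for batch in collector:
    pass
collector.shutdown()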