Hi, I’ve been moving a custom environment from dynamically sized to statically known observation sizes, and I came across unexpected behavior in how the input and output TensorDicts of _step() are handled.
What is the expected convention for how next = _step(td)
should read and write its observations? Am I doing anything wrong?
Based on the documentation, I expected the following:
- “td” contains the observations from the last output in the trajectory
- “next” is the new output, which in turn becomes the “last output” for the following step (see the sketch below)
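Concretely, this is the dataflow I expect during a rollout (a rough sketch of my mental model using the public API, not TorchRL’s actual implementation; step_mdp is the helper from torchrl.envs.utils):

    from torchrl.envs.utils import step_mdp

    td = env.reset()                           # root td holds the first observation
    for _ in range(3):
        td["action"] = env.action_spec.rand()  # random action, as rollout() does
        td = env.step(td)                      # internally, roughly: td["next"] = self._step(td)
        td = step_mdp(td)                      # "next" becomes the root, i.e. the input to the next _step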
I observe this when running with “use_buffers=False” (or with dynamically sized inputs, since that also disables the buffers). But when running with “use_buffers=True”, every step reads from the initial reset instead.
I created a minimal example that simply takes in the last observation and increments it.
Testing ParallelEnv
========================================
No buffers (expected behavior)
Input: tensor(1)
Output: tensor(2)
Input: tensor(2)
Output: tensor(3)
Input: tensor(3)
Output: tensor(4)
========================================
Buffers (unexpected behavior)
Input: tensor(1)
Output: tensor(2)
Input: tensor(1)
Output: tensor(2)
Input: tensor(1)
Output: tensor(2)
Code to reproduce:
import torch
from typing import Optional

from tensordict import TensorDict
from torchrl.data import Binary, Bounded, Composite, TensorSpec, Unbounded
from torchrl.envs import EnvBase, check_env_specs


class TestEnv(EnvBase):
    """Minimal env whose observation is a counter incremented by one each step."""

    def __init__(self, seed: int = 0, device="cpu"):
        super().__init__(device=device)
        self.observation_spec = self._create_observation_spec()
        self.action_spec = self._create_action_spec()
        self.reward_spec = self._create_reward_spec()
        self.done_spec = Binary(shape=(1,), device=self.device, dtype=torch.bool)

    def _create_observation_spec(self) -> TensorSpec:
        # Nested spec: ("observation", "value")
        obs = Unbounded(shape=(1,), device=self.device, dtype=torch.int64)
        return Composite(observation=Composite(value=obs))

    def _create_action_spec(self) -> TensorSpec:
        out = Bounded(
            shape=(1,),
            device=self.device,
            dtype=torch.int32,
            low=torch.tensor(0, device=self.device),
            high=torch.tensor(4, device=self.device),
        )
        return Composite(action=out)

    def _create_reward_spec(self) -> TensorSpec:
        return Unbounded(shape=(1,), device=self.device, dtype=torch.float32)

    def _get_observation(self, v) -> TensorDict:
        # Build a fresh observation TensorDict holding v + 1
        obs = TensorDict(value=torch.tensor([0], device=self.device, dtype=torch.int64))
        obs["value"][0] = v + 1
        return TensorDict(observation=obs)

    def _step(self, td: TensorDict) -> TensorDict:
        print("Input: ", td["observation"]["value"][0])
        out = self._get_observation(td["observation"]["value"][0])
        print("Output: ", out["observation"]["value"][0])
        out.set("reward", torch.tensor([1], device=self.device, dtype=torch.float32))
        out.set("done", torch.tensor([False], device=self.device, dtype=torch.bool))
        return out

    def _reset(self, td: Optional[TensorDict] = None) -> TensorDict:
        # Counter starts at 1 (0 + 1)
        return self._get_observation(0)

    def _set_seed(self, seed: Optional[int] = None):
        self.rng = torch.manual_seed(seed)


def make_env():
    return TestEnv()


if __name__ == "__main__":
    from torchrl.envs import ParallelEnv

    check_env_specs(make_env())  # sanity-check the env specs

    print("Testing ParallelEnv")
    print("========================================")
    print("No buffers")
    workers = 1
    penv = ParallelEnv(
        workers,
        [make_env for _ in range(workers)],
        use_buffers=False,
    )
    r = penv.rollout(3)

    print("========================================")
    print("Buffers")
    workers = 1
    penv = ParallelEnv(
        workers,
        [make_env for _ in range(workers)],
        use_buffers=True,
    )
    r = penv.rollout(3)
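As a baseline, the env can also be run serially without ParallelEnv; I would expect the counter to increment normally there, since no shared buffers are involved:

    # Serial baseline (my expectation: Input 1 / Output 2, Input 2 / Output 3, ...)
    env = TestEnv()
    env.rollout(3)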
Also, is there a way to write to the preallocated memory “in place”, to avoid temporary allocations for new observations at each step (before they are copied into the buffers)? Something like the sketch below is what I have in mind.
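For illustration, here is roughly what I mean (a hypothetical sketch: “_obs_out” is a name I made up, and I don’t know whether ParallelEnv’s buffer protocol tolerates returning the same TensorDict every step):

    # Inside TestEnv (hypothetical variant):
    def __init__(self, seed: int = 0, device="cpu"):
        super().__init__(device=device)
        # ... spec setup as before ...
        # Hypothetical: preallocate the step output once
        self._obs_out = TensorDict(
            observation=TensorDict(value=torch.zeros(1, device=self.device, dtype=torch.int64)),
            reward=torch.zeros(1, device=self.device, dtype=torch.float32),
            done=torch.zeros(1, device=self.device, dtype=torch.bool),
        )

    def _step(self, td: TensorDict) -> TensorDict:
        # Write into the preallocated tensors instead of allocating new ones
        self._obs_out["observation", "value"].copy_(td["observation", "value"] + 1)
        self._obs_out["reward"].fill_(1.0)
        self._obs_out["done"].fill_(False)
        return self._obs_out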