Hello, I am opening this topic … my contribution here is the simplest one you can imagine!
from typing import Optional
import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase
from torchrl.envs.utils import check_env_specs
"""
A very, very simple objective: an actor appears at a random location x on a line and must reach
the center of the line.
"""
################
DEFAULT_X = 1
rendering = True
render_fps = 30
################
if rendering:
    import pygame
    from pygame import gfxdraw
def gen_params(batch_size=None):
    if batch_size is None:
        batch_size = []
    td = TensorDict({}, [])
    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td
class SimpleEnv(EnvBase):
    batch_locked = False

    def __init__(self, td_params=None, seed=None, device="cpu"):
        super().__init__(device=device, batch_size=[])
        if td_params is None:
            td_params = self.gen_params()
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)
        self.screen = None
        self.clock = None
    def _step(self, tensordict):
        x = tensordict["x"]
        u = tensordict["action"].squeeze(-1)
        # u = u.clamp(-1, 1)
        # negative squared distance to the center, reshaped to match reward_spec (*batch, 1)
        reward = (-torch.abs(x + u) ** 2).view(*tensordict.shape, 1)
        new_x = x + u
        done = torch.zeros_like(reward, dtype=torch.bool)
        nextTD = TensorDict(
            {"x": new_x, "reward": reward, "done": done},
            tensordict.shape,
        )
        if rendering:
            state = {"x": tensordict["x"].tolist()}
            self.render(state)
        return nextTD
    def _reset(self, tensordict):
        if tensordict is None:
            tensordict = self.gen_params(batch_size=self.batch_size)
        high_x = torch.tensor(DEFAULT_X, device=self.device)
        low_x = -high_x
        # sample the initial position uniformly in [-DEFAULT_X, DEFAULT_X]
        x = torch.rand(tensordict.shape, generator=self.rng, device=self.device) * (high_x - low_x) + low_x
        out = TensorDict({"x": x}, batch_size=tensordict.shape)
        if rendering:
            state = {"x": out["x"].tolist()}
            self.render(state)
        return out
    gen_params = staticmethod(gen_params)
    def _make_spec(self, td):
        self.observation_spec = CompositeSpec(
            x=BoundedTensorSpec(
                minimum=-DEFAULT_X,
                maximum=DEFAULT_X,
                shape=(),
                dtype=torch.float32,
            ),
            shape=(),
        )
        # self.state_spec = self.observation_spec.clone()
        self.action_spec = BoundedTensorSpec(
            minimum=-DEFAULT_X / 100,
            maximum=DEFAULT_X / 100,
            shape=(1,),
            dtype=torch.float32,
        )
        self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td.shape, 1))
    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng
    def render(self, state):
        self.screen_w = 600
        self.screen_h = 400
        if self.screen is None:
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((self.screen_w, self.screen_h))
            pygame.display.set_caption("View from the first element of the batch")
        if self.clock is None:
            self.clock = pygame.time.Clock()
        self.surf = pygame.Surface((self.screen_w, self.screen_h))
        self.surf.fill((255, 255, 255))
        bound = DEFAULT_X
        scale = self.screen_w / (bound * 2)
        offset = self.screen_w // 2
        agent_width = int(0.05 * scale)
        gfxdraw.hline(self.surf, 0, self.screen_w, self.screen_h // 2, (0, 0, 0))
        if not isinstance(state["x"], float):
            x = state["x"][0]
        else:
            x = state["x"]
        # drawing the agent
        gfxdraw.aacircle(self.surf, int(x * scale) + offset, self.screen_h // 2, int(agent_width / 2), (255, 0, 0))
        gfxdraw.filled_circle(self.surf, int(x * scale) + offset, self.screen_h // 2, int(agent_width / 2), (255, 0, 0))
        # drawing the center target
        gfxdraw.aacircle(self.surf, offset, self.screen_h // 2, int(agent_width / 4), (0, 0, 0))
        # self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        self.clock.tick(render_fps)
        pygame.display.flip()
    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()

    def get_obskeys(self):
        return ["x"]
if __name__ == "__main__":
    # Tests on the environment
    env = SimpleEnv()
    check_env_specs(env)
    print("\n* observation_spec:", env.observation_spec)
    print("\n* action_spec:", env.action_spec)
    print("\n* reward_spec:", env.reward_spec)
    print("\n* random action: \n", env.action_spec.rand())
    td = env.reset()
    print("\n* reset", td)
    td = env.rand_step(td)
    print("\n* random step tensordict", td)
    rollout = env.rollout(3)
    print("\n* rollout of three steps:", rollout)
    print("\n* Shape of the rollout TensorDict:", rollout.batch_size)
    batch_size = 10
    td = env.reset(env.gen_params(batch_size=[batch_size]))
    print("\n* reset (batch size of 10)", td)
    env.close()
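As an extra sanity check, here is a small snippet of my own (not part of the file above; GreedyPolicy is just a name I made up, and it assumes the environment is saved as Simple1D1AgentEnv.py, as the training script below does). It rolls out a hand-coded policy that always takes the largest allowed step toward the center, so the reward should climb toward zero:
import torch
from tensordict.nn import TensorDictModule
from Simple1D1AgentEnv import SimpleEnv, DEFAULT_X

class GreedyPolicy(torch.nn.Module):
    # deterministic policy: always step toward the center with the maximal allowed amplitude
    def forward(self, x):
        return (-torch.sign(x) * (DEFAULT_X / 100)).unsqueeze(-1)

env = SimpleEnv()
policy = TensorDictModule(GreedyPolicy(), in_keys=["x"], out_keys=["action"])
rollout = env.rollout(200, policy)
print("mean reward over the rollout:", rollout["next", "reward"].mean().item())
env.close()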
And here is a minimal network for training an agent on this minimal environment:
import torch
from torch import nn
from collections import defaultdict
import matplotlib.pyplot as plt
from tensordict.nn import TensorDictModule
from torchrl.envs import CatTensors,TransformedEnv,UnsqueezeTransform
from torchrl.envs.transforms.transforms import _apply_to_composite
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
import Simple1D1AgentEnv
env = Simple1D1AgentEnv.SimpleEnv()
obs_keys = env.get_obskeys()
env = TransformedEnv(
    env,
    UnsqueezeTransform(
        unsqueeze_dim=-1,
        in_keys=obs_keys,
        in_keys_inv=obs_keys,
    ),
)
cat_transform = CatTensors(
    in_keys=obs_keys, dim=-1, out_key="observation", del_keys=False
)
env.append_transform(cat_transform)
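# Note (comment only): the pipeline above unsqueezes the scalar "x" to shape (1,) and CatTensors
# then writes it under the "observation" key that the policy reads. With a single observation key
# the concatenation is trivial, but it keeps the same structure as the pendulum tutorial, so more
# observation keys can be added later without touching the policy.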
torch.manual_seed(0)
env.set_seed(0)
E = 64
net = nn.Sequential(
    nn.LazyLinear(E),
    nn.Tanh(),
    nn.LazyLinear(E),
    nn.Tanh(),
    nn.LazyLinear(E),
    nn.Tanh(),
    nn.LazyLinear(1),
)
policy = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=["action"],
)
optim = torch.optim.Adam(policy.parameters(), lr=2e-3)
batch_size = 32
frames = 10_000
iterations = frames // batch_size
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, frames)
logs = defaultdict(list)
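# Training follows the pendulum tutorial: the dynamics x_{t+1} = x_t + u_t are differentiable,
# so the mean rollout reward can be backpropagated straight through the environment into the
# policy parameters (no value function and no policy-gradient estimator are needed).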
for i in range(iterations):
    init_td = env.reset(env.gen_params(batch_size=[batch_size]))
    rollout = env.rollout(100, policy, tensordict=init_td, auto_reset=False)
    traj_return = rollout["next", "reward"].mean()
    (-traj_return).backward()
    gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    progress = 100 * i / iterations
    print(
        f"{progress: 3.1f}%, reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean().item())
    logs["grad_norm"].append(gn)
    scheduler.step()
env.close()
def plot():
    import matplotlib
    from matplotlib import pyplot as plt

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 3, 1)
    plt.plot(logs["return"])
    plt.title("returns")
    plt.xlabel("iteration")
    plt.subplot(1, 3, 2)
    plt.plot(logs["last_reward"])
    plt.title("last reward")
    plt.xlabel("iteration")
    plt.subplot(1, 3, 3)
    plt.plot(logs["grad_norm"])
    plt.title("grad_norm")
    plt.xlabel("iteration")
    plt.show()
plot()
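To watch what the policy actually learned, here is a small evaluation sketch of my own (eval_env is a name I introduce; it simply rebuilds the transformed environment from above, since env.close() has already shut down the pygame window):
# evaluate the trained policy on a fresh, rendered environment instance
eval_env = TransformedEnv(
    Simple1D1AgentEnv.SimpleEnv(),
    UnsqueezeTransform(unsqueeze_dim=-1, in_keys=obs_keys, in_keys_inv=obs_keys),
)
eval_env.append_transform(
    CatTensors(in_keys=obs_keys, dim=-1, out_key="observation", del_keys=False)
)
with torch.no_grad():
    eval_rollout = eval_env.rollout(200, policy)
print("evaluation mean reward:", eval_rollout["next", "reward"].mean().item())
eval_env.close()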
Thanks a lot for the tutorial:
https://pytorch.org/rl/tutorials/pendulum.html