No, I am running coding_ppo.ipynb located in my home folder (instead of the rl folder). I exported the .ipynb to .py and removed the rl repo folder, but it still raises the same error. Would you mind testing it on your side?
# %%
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm
# %% [markdown]
# ### Hyperparameters
# %%
# training parameters
device = "cuda"
num_cells = 256
lr = 3e-4
max_grad_norm = 1.0
# data collection parameters
frame_skip = 1
frames_per_batch = 1000 // frame_skip
# For a complete training, bring the number of frames up to 1M
total_frames = 100000 // frame_skip
# PPO parameters
sub_batch_size = 64
num_epochs = 10
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4
# %% [markdown]
# ### Environment Definition
# %%
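# Base Gym env wrapped with observation normalization, double->float casting, and a step counter.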
base_env = GymEnv("InvertedDoublePendulum-v4", device=device, frame_skip=frame_skip)
env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(in_keys=["observation"]),
        StepCounter(),
    ),
)
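# ObservationNorm needs loc/scale statistics; init_stats runs 1000 random frames through the env to estimate them.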
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
# %%
print("normalization constant shape:", env.transform[0].loc.shape)
print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("done_spec:", env.done_spec)
print("action_spec:", env.action_spec)
# print("state_spec:", env.state_spec)
check_env_specs(env)
# %% [markdown]
# ### PPO Policy
# %%
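# Actor MLP: the last layer outputs 2 * action_dim values that NormalParamExtractor splits into loc and scale.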
actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    NormalParamExtractor(),
)
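# Wrap the network so it reads "observation" from the tensordict and writes "loc" and "scale".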
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)
policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "min": env.action_spec.space.minimum,
        "max": env.action_spec.space.maximum,
    },
    # we'll need the log-prob for the numerator of the importance weights
    return_log_prob=True,
)
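# Sanity check: run the untrained policy once on a reset tensordict.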
print("Running policy:", policy_module(env.reset()))
# %% [markdown]
# ### Value Network
# %%
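# Critic MLP: same architecture as the actor but with a single scalar output (the state value).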
value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)
value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
    # ValueOperator writes "state_value" by default, which is what GAE and ClipPPOLoss expect;
    # an empty out_keys list raises an error because the module returns one tensor but no key to write it to.
    out_keys=["state_value"],
)
print("Running value:", value_module(env.reset()))
# %% [markdown]
# ### Data collector and Replay buffer
# %%
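# The collector runs the policy in the env and yields tensordicts of frames_per_batch frames.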
collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)
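# The buffer holds one collected batch; SamplerWithoutReplacement yields each frame once per epoch.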
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)
# %% [markdown]
# ### Loss function
# %%
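# GAE computes the "advantage" and value-target entries consumed by the PPO loss below.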
advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
)
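# Clipped PPO objective with a value loss and a small entropy bonus.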
loss_module = ClipPPOLoss(
    actor=policy_module,
    critic=value_module,
    advantage_key="advantage",
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    value_target_key=advantage_module.value_target_key,
    critic_coef=1.0,
    gamma=0.99,
    loss_critic_type="smooth_l1",
)