Bug in Torchrl Tutorial PPO Example

Hi community,

I am following the torchrl tutorial PPO example to learn how torchrl works.

However, the loss function part of the tutorial raises an error. When I run it on my side (torchrl v0.1.1 + pytorch 2.0), I get:

Traceback (most recent call last)
Cell In[12], line 1
----> 1 advantage_module = GAE(
      2     gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
      3 )
      4 loss_module = ClipPPOLoss(
      5     actor=policy_module,
      6     critic=value_module,
   (...)
     15     loss_critic_type="smooth_l1",
     16 )
     18 optim = torch.optim.Adam(loss_module.parameters(), lr)

File ~/miniconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl/objectives/value/advantages.py:779, in GAE.__init__(self, gamma, lmbda, value_network, average_gae, differentiable, vectorized, advantage_key, value_target_key, value_key, skip_existing)
    765 def __init__(
    766     self,
    767     *,
   (...)
    777     skip_existing: Optional[bool] = None,
    778 ):
--> 779     super().__init__(
    780         value_network=value_network,
    781         differentiable=differentiable,
    782         advantage_key=advantage_key,
...
    119     )
    121 self.advantage_key = advantage_key
    122 self.value_target_key = value_target_key

KeyError: "value key 'state_value' not found in value network out_keys."

Besides, there is an error in the training output part of the DQN example.

Is there any example that runs successfully?

Thanks

Thanks for raising this, we’ll issue a fix asap

Thanks. Please let me know when your team fixes it.

This is weird, I can execute the code locally on torchrl v0.1.1, torch v2.0.1 and tensordict v0.1.2
Are you using these versions?

I am using torchrl v0.1.1, torch v2.0.0 and tensordict v0.1.2.

Besides, this

print("state_spec:", env.state_spec)

also raises an error on my side:

AttributeError: 'InvertedDoublePendulumEnv' object has no attribute 'state_spec'

I updated pytorch from v2.0.0 to v2.0.1 and the error is still there.

In addition, I have also tried the examples in the GitHub repo.

None of them works for me. For example,

$python sac/sac.py env_name="HalfCheetah-v4" env_task="" env_library="gym"


sys:1: UserWarning:
'config' is validated against ConfigStore schema with the same name.
This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
/home/hai/miniconda3/envs/torch_rl/lib/python3.9/site-packages/hydra/main.py:94: UserWarning:
'config' is validated against ConfigStore schema with the same name.
This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
  _run_hydra(
self.log_dir: sac_logging/SAC__c1c1aac0_23_07_10-15_27_52
/home/hai/miniconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
/home/hai/miniconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl/collectors/collectors.py:1182: UserWarning: total_frames (1000000) is not exactly divisible by frames_per_batch (1024).This means 448 additional frames will be collected.To silence this message, set the environment variable RL_WARNINGS to False.
  warnings.warn(
Error executing job with overrides: ['env_name=HalfCheetah-v4', 'env_task=', 'env_library=gym']
Traceback (most recent call last):
  File "/home/hai/rl/examples/sac/sac.py", line 173, in main
    recorder.transform[1:].load_state_dict(create_env_fn().transform.state_dict())
  File "/home/hai/miniconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Compose:
        Missing key(s) in state_dict: "transforms.1.standard_normal", "transforms.1.loc", "transforms.1.scale".
        Unexpected key(s) in state_dict: "transforms.0.standard_normal", "transforms.0.loc", "transforms.0.scale", "transforms.2.standard_normal", "transforms.2.loc", "transforms.2.scale".

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Can you print what your conda env looks like?

# packages in environment at /home/hai/miniconda3/envs/torch_rl:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
absl-py                   1.4.0                    pypi_0    pypi
ale-py                    0.8.1                    pypi_0    pypi
antlr4-python3-runtime    4.9.3                    pypi_0    pypi
appdirs                   1.4.4                    pypi_0    pypi
asttokens                 2.2.1              pyhd8ed1ab_0    conda-forge
autorom                   0.4.2                    pypi_0    pypi
autorom-accept-rom-license 0.6.1                    pypi_0    pypi
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                pyhd8ed1ab_3    conda-forge
backports.functools_lru_cache 1.6.5              pyhd8ed1ab_0    conda-forge
blas                      1.0                         mkl  
brotlipy                  0.7.0           py39h27cfd23_1003  
bzip2                     1.0.8                h7b6447c_0  
ca-certificates           2023.5.7             hbcca054_0    conda-forge
cachetools                5.3.1                    pypi_0    pypi
certifi                   2023.5.7           pyhd8ed1ab_0    conda-forge
cffi                      1.15.1           py39h5eee18b_3  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
click                     8.1.4                    pypi_0    pypi
cloudpickle               2.2.1                    pypi_0    pypi
contourpy                 1.1.0                    pypi_0    pypi
cryptography              39.0.1           py39h9ce1e76_2  
cuda                      11.7.1                        0    nvidia
cuda-cccl                 11.7.91                       0    nvidia
cuda-command-line-tools   11.7.1                        0    nvidia
cuda-compiler             11.7.1                        0    nvidia
cuda-cudart               11.7.99                       0    nvidia
cuda-cudart-dev           11.7.99                       0    nvidia
cuda-cuobjdump            11.7.91                       0    nvidia
cuda-cupti                11.7.101                      0    nvidia
cuda-cuxxfilt             11.7.91                       0    nvidia
cuda-demo-suite           12.2.53                       0    nvidia
cuda-documentation        12.2.53                       0    nvidia
cuda-driver-dev           11.7.99                       0    nvidia
cuda-gdb                  12.2.53                       0    nvidia
cuda-libraries            11.7.1                        0    nvidia
cuda-libraries-dev        11.7.1                        0    nvidia
cuda-memcheck             11.8.86                       0    nvidia
cuda-nsight               12.2.53                       0    nvidia
cuda-nsight-compute       12.2.0                        0    nvidia
cuda-nvcc                 11.7.99                       0    nvidia
cuda-nvdisasm             12.2.53                       0    nvidia
cuda-nvml-dev             11.7.91                       0    nvidia
cuda-nvprof               12.2.60                       0    nvidia
cuda-nvprune              11.7.91                       0    nvidia
cuda-nvrtc                11.7.99                       0    nvidia
cuda-nvrtc-dev            11.7.99                       0    nvidia
cuda-nvtx                 11.7.91                       0    nvidia
cuda-nvvp                 12.2.60                       0    nvidia
cuda-runtime              11.7.1                        0    nvidia
cuda-sanitizer-api        12.2.53                       0    nvidia
cuda-toolkit              11.7.1                        0    nvidia
cuda-tools                11.7.1                        0    nvidia
cuda-visual-tools         11.7.1                        0    nvidia
cycler                    0.11.0                   pypi_0    pypi
cython                    0.29.36                  pypi_0    pypi
debugpy                   1.5.1            py39h295c915_0  
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
docker-pycreds            0.4.0                    pypi_0    pypi
entrypoints               0.4                pyhd8ed1ab_0    conda-forge
exceptiongroup            1.1.2                    pypi_0    pypi
executing                 1.2.0              pyhd8ed1ab_0    conda-forge
fasteners                 0.18                     pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.9.0            py39h06a4308_0  
fonttools                 4.40.0                   pypi_0    pypi
freetype                  2.12.1               h4a9f257_0  
gds-tools                 1.7.0.149                     0    nvidia
giflib                    5.2.1                h5eee18b_3  
gitdb                     4.0.10                   pypi_0    pypi
gitpython                 3.1.31                   pypi_0    pypi
glfw                      2.6.2                    pypi_0    pypi
gmp                       6.2.1                h295c915_3  
gmpy2                     2.1.2            py39heeb90bb_0  
gnutls                    3.6.15               he1e5248_0  
google-auth               2.21.0                   pypi_0    pypi
google-auth-oauthlib      1.0.0                    pypi_0    pypi
grpcio                    1.56.0                   pypi_0    pypi
gym                       0.26.2                   pypi_0    pypi
gym-notices               0.0.8                    pypi_0    pypi
hydra-core                1.3.2                    pypi_0    pypi
hydra-submitit-launcher   1.2.0                    pypi_0    pypi
idna                      3.4              py39h06a4308_0  
imageio                   2.31.1                   pypi_0    pypi
importlib-metadata        6.8.0                    pypi_0    pypi
importlib-resources       6.0.0                    pypi_0    pypi
iniconfig                 2.0.0                    pypi_0    pypi
intel-openmp              2023.1.0         hdb19cb5_46305  
ipykernel                 6.15.0             pyh210e3f2_0    conda-forge
ipython                   8.14.0             pyh41d4057_0    conda-forge
jedi                      0.18.2             pyhd8ed1ab_0    conda-forge
jinja2                    3.1.2            py39h06a4308_0  
jpeg                      9e                   h5eee18b_1  
jupyter_client            7.0.6              pyhd8ed1ab_0    conda-forge
jupyter_core              4.12.0           py39hf3d152e_0    conda-forge
kiwisolver                1.4.4                    pypi_0    pypi
lame                      3.100                h7b6447c_0  
lcms2                     2.12                 h3be6417_0  
ld_impl_linux-64          2.38                 h1181459_1  
lerc                      3.0                  h295c915_0  
libcublas                 11.11.3.6                     0    nvidia
libcublas-dev             11.11.3.6                     0    nvidia
libcufft                  10.9.0.58                     0    nvidia
libcufft-dev              10.9.0.58                     0    nvidia
libcufile                 1.7.0.149                     0    nvidia
libcufile-dev             1.7.0.149                     0    nvidia
libcurand                 10.3.3.53                     0    nvidia
libcurand-dev             10.3.3.53                     0    nvidia
libcusolver               11.4.1.48                     0    nvidia
libcusolver-dev           11.4.1.48                     0    nvidia
libcusparse               11.7.5.86                     0    nvidia
libcusparse-dev           11.7.5.86                     0    nvidia
libdeflate                1.17                 h5eee18b_0  
libffi                    3.4.4                h6a678d5_0  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libiconv                  1.16                 h7f8727e_2  
libidn2                   2.3.4                h5eee18b_0  
libnpp                    11.8.0.86                     0    nvidia
libnpp-dev                11.8.0.86                     0    nvidia
libnvjpeg                 11.9.0.86                     0    nvidia
libnvjpeg-dev             11.9.0.86                     0    nvidia
libpng                    1.6.39               h5eee18b_0  
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libstdcxx-ng              11.2.0               h1234567_1  
libtasn1                  4.19.0               h5eee18b_0  
libtiff                   4.5.0                h6a678d5_2  
libunistring              0.9.10               h27cfd23_0  
libwebp                   1.2.4                h11a3e52_1  
libwebp-base              1.2.4                h5eee18b_1  
lz4-c                     1.9.4                h6a678d5_0  
markdown                  3.4.3                    pypi_0    pypi
markupsafe                2.1.1            py39h7f8727e_0  
matplotlib                3.7.2                    pypi_0    pypi
matplotlib-inline         0.1.6              pyhd8ed1ab_0    conda-forge
mkl                       2023.1.0         h6d00ec8_46342  
mkl-service               2.4.0            py39h5eee18b_1  
mkl_fft                   1.3.6            py39h417a72b_1  
mkl_random                1.2.2            py39h417a72b_1  
mpc                       1.1.0                h10f8cd9_1  
mpfr                      4.0.2                hb69a4c5_1  
mpmath                    1.2.1            py39h06a4308_0  
mujoco                    2.3.6                    pypi_0    pypi
mujoco-py                 2.1.2.14                 pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
nest-asyncio              1.5.6              pyhd8ed1ab_0    conda-forge
nettle                    3.7.3                hbbd107a_1  
networkx                  2.8.4            py39h06a4308_1  
nsight-compute            2023.2.0.16                   0    nvidia
numpy                     1.25.0           py39h5f9d8c6_0  
numpy-base                1.25.0           py39hb5e798b_0  
oauthlib                  3.2.2                    pypi_0    pypi
omegaconf                 2.3.0                    pypi_0    pypi
openh264                  2.1.1                h4ff587b_0  
openssl                   3.0.9                h7f8727e_0  
packaging                 23.1               pyhd8ed1ab_0    conda-forge
parso                     0.8.3              pyhd8ed1ab_0    conda-forge
pathtools                 0.1.2                    pypi_0    pypi
pexpect                   4.8.0              pyh1a96a4e_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    9.4.0            py39h6a678d5_0  
pip                       23.1.2           py39h06a4308_0  
pluggy                    1.2.0                    pypi_0    pypi
prompt-toolkit            3.0.39             pyha770c72_0    conda-forge
prompt_toolkit            3.0.39               hd8ed1ab_0    conda-forge
protobuf                  4.23.4                   pypi_0    pypi
psutil                    5.9.5                    pypi_0    pypi
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
pyasn1                    0.5.0                    pypi_0    pypi
pyasn1-modules            0.3.0                    pypi_0    pypi
pycparser                 2.21               pyhd3eb1b0_0  
pygame                    2.5.0                    pypi_0    pypi
pygments                  2.15.1             pyhd8ed1ab_0    conda-forge
pyopengl                  3.1.7                    pypi_0    pypi
pyopenssl                 23.0.0           py39h06a4308_0  
pyparsing                 3.0.9                    pypi_0    pypi
pysocks                   1.7.1            py39h06a4308_0  
pytest                    7.4.0                    pypi_0    pypi
pytest-instafail          0.5.0                    pypi_0    pypi
python                    3.9.17               h955ad1f_0  
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python_abi                3.9                      2_cp39    conda-forge
pytorch                   2.0.1           py3.9_cuda11.7_cudnn8.5.0_0    pytorch
pytorch-cuda              11.7                 h67b0de4_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pyyaml                    6.0                      pypi_0    pypi
pyzmq                     19.0.2           py39hb69f2a1_2    conda-forge
readline                  8.2                  h5eee18b_0  
requests                  2.29.0           py39h06a4308_0  
requests-oauthlib         1.3.1                    pypi_0    pypi
rsa                       4.9                      pypi_0    pypi
sentry-sdk                1.27.1                   pypi_0    pypi
setproctitle              1.3.2                    pypi_0    pypi
setuptools                67.8.0           py39h06a4308_0  
six                       1.16.0             pyh6c4a22f_0    conda-forge
smmap                     5.0.0                    pypi_0    pypi
sqlite                    3.41.2               h5eee18b_0  
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
submitit                  1.4.5                    pypi_0    pypi
sympy                     1.11.1           py39h06a4308_0  
tbb                       2021.8.0             hdb19cb5_0  
tensorboard               2.13.0                   pypi_0    pypi
tensorboard-data-server   0.7.1                    pypi_0    pypi
tensordict                0.1.2                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0  
tomli                     2.0.1                    pypi_0    pypi
torchaudio                2.0.2                py39_cu117    pytorch
torchrl                   0.1.1                    pypi_0    pypi
torchtriton               2.0.0                      py39    pytorch
torchvision               0.15.2               py39_cu117    pytorch
tornado                   6.1              py39hb9d737c_3    conda-forge
tqdm                      4.65.0                   pypi_0    pypi
traitlets                 5.9.0              pyhd8ed1ab_0    conda-forge
typing_extensions         4.6.3            py39h06a4308_0  
tzdata                    2023c                h04d1e81_0  
urllib3                   1.26.16          py39h06a4308_0  
wandb                     0.15.5                   pypi_0    pypi
wcwidth                   0.2.6              pyhd8ed1ab_0    conda-forge
werkzeug                  2.3.6                    pypi_0    pypi
wheel                     0.38.4           py39h06a4308_0  
xz                        5.4.2                h5eee18b_0  
zeromq                    4.3.4                h9c3ff4c_1    conda-forge
zipp                      3.15.0                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_0  
zstd                      1.5.5                hc292b87_0  

I might have an explanation:
I think (but could be wrong) that you cloned torchrl, and you’re executing the examples on main with torchrl 0.1.1, but the main branch of torchrl corresponds to 0.2.0dev (the next release).
So if I’m right, either check out v0.1.1 in your GitHub clone or use the nightly release.

I checked out the v0.1.1 tag, not the main branch.

$ git status
HEAD detached at v0.1.1

Could it be that you’re executing the code within path/to/torchrl and that python is struggling between importing torchrl from your conda env and the local folder named torchrl?
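One way to check both hypotheses (a version mismatch, or a local folder shadowing the installed package) is to print the installed version and the location Python would import each package from. The sketch below uses only the standard library; `locate_package` is a hypothetical helper name and not part of torchrl.

```python
import importlib.util
from importlib import metadata
from typing import Optional, Tuple


def locate_package(name: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (installed version, import location); None where unknown."""
    try:
        # Version recorded by pip/conda for the installed distribution.
        version = metadata.version(name)
    except metadata.PackageNotFoundError:
        version = None
    # Where the import system would load the package from, without importing it.
    spec = importlib.util.find_spec(name)
    origin = spec.origin if spec is not None else None
    return version, origin


for pkg in ("torch", "tensordict", "torchrl"):
    version, origin = locate_package(pkg)
    print(f"{pkg}: version={version}, imported from={origin}")
```

If the location printed for torchrl points inside a cloned repository rather than site-packages, the local folder is the one being imported.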

No, I am running coding_ppo.ipynb located in my home folder (not in the rl folder). I exported the .ipynb to .py and removed the rl repo folder. It still raises the same error. Would you mind testing it on your side?

# %%
from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm

# %% [markdown]
# ### Hyper parameters

# %%
# training parameters
device = "cuda"
num_cells = 256
lr = 3e-4
max_grad_norm = 1.0

# data collection parameters
frame_skip = 1
frames_per_batch = 1000 // frame_skip
# For a complete training, bring the number of frames up to 1M
total_frames = 100000 // frame_skip

# PPO parameters
sub_batch_size = 64
num_epochs = 10
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

# %% [markdown]
# ### Environment Define

# %%
base_env = GymEnv("InvertedDoublePendulum-v4", device=device, frame_skip=frame_skip)
env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(
            in_keys=["observation"],
        ),
        StepCounter(),
    ),
)
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

# %%
print("normalization constant shape:", env.transform[0].loc.shape)
print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("done_spec:", env.done_spec)
print("action_spec:", env.action_spec)
# print("state_spec:", env.state_spec)
check_env_specs(env)

# %% [markdown]
# ### PPO Policy

# %%
actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    NormalParamExtractor(),
)
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)
policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "min": env.action_spec.space.minimum,
        "max": env.action_spec.space.maximum,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)
print("Running policy:", policy_module(env.reset()))

# %% [markdown]
# ### Value Network

# %%
value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)

value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
    out_keys=[]
)

print("Running value:", value_module(env.reset()))

# %% [markdown]
# ### Data collector and Replay buffer

# %%
collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

# %% [markdown]
# ### Loss function

# %%
advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
)
loss_module = ClipPPOLoss(
    actor=policy_module,
    critic=value_module,
    advantage_key="advantage",
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    value_target_key=advantage_module.value_target_key,
    critic_coef=1.0,
    gamma=0.99,
    loss_critic_type="smooth_l1",
)

Why is out_keys empty?
That seems like the most obvious explanation for the error message above: I think if you comment out the out_keys=[] line, things should work OK.

Hmm… This is the issue.

When I was reading the tutorial, I thought the ValueOperator here should be equivalent or similar to TensorDictModule. So I tried replacing ValueOperator with TensorDictModule and added ‘out_keys=[]’ (should it be out_keys=None?). I found that this call works

print("Running value:", value_module(env.reset()))

with both ValueOperator and TensorDictModule.

Thanks for your time on my noob bug here.

Best,

ValueOperator automatically computes the out_keys for you. If you’re using a regular TDModule, you can pass out_keys=["state_value"] or whichever other name you please (provided you tell GAE and PPO where to find the value)
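To illustrate the difference between omitting out_keys and passing an empty list, here is a toy model in plain Python. It is not torchrl's source code: `resolve_out_keys` and `check_value_key` are hypothetical stand-ins, and the assumption is simply that the default key is filled in only when out_keys is None, which is consistent with the behaviour seen in this thread.

```python
from typing import List, Optional


def resolve_out_keys(
    in_keys: List[str], out_keys: Optional[List[str]] = None
) -> List[str]:
    """Toy stand-in for how a value module might pick its output keys."""
    if out_keys is None:
        # Only None triggers the default; an explicit [] is kept as-is.
        return ["state_value"]
    return list(out_keys)


def check_value_key(out_keys: List[str], value_key: str = "state_value") -> None:
    """Toy stand-in for the check done when the advantage module is built."""
    if value_key not in out_keys:
        raise KeyError(
            f"value key '{value_key}' not found in value network out_keys."
        )


# Omitting out_keys yields the default key, so the check passes.
default_keys = resolve_out_keys(in_keys=["observation"])
check_value_key(default_keys)

# An empty list leaves the module without a value key, so the check fails.
empty_keys = resolve_out_keys(in_keys=["observation"], out_keys=[])
try:
    check_value_key(empty_keys)
    failed = False
except KeyError:
    failed = True
```

In the script above, this corresponds to dropping the out_keys=[] line from the ValueOperator call, or to passing out_keys=["state_value"] explicitly when using a plain TensorDictModule.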