I’m trying to make a NN learn to play the CartPole-v1 game from gymnasium, and I followed a similar setup to the one in this tutorial:
Reinforcement Learning (PPO) with TorchRL Tutorial — PyTorch Tutorials 2.5.0+cu124 documentation , only changing a few parameters to make it work with the cart pole game and not the original double pendulum.
I get this error, probably due to my setup of collector:
C:\programming\zoomino 8\blockblastpy\.venv3.12\Lib\site-packages\tensordict\_td.py:2663: UserWarning: An output with one or more elements was resized since it had shape [1000, 2], which does not match the required output shape [1000, 1]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\Resize.cpp:35.)
new_dest = torch.stack(
Traceback (most recent call last):
File "C:\programming\zoomino 8\blockblastpy\rl\torchrl\collectors\collectors.py", line 1225, in rollout
result = torch.stack(
^^^^^^^^^^^^
File "C:\programming\zoomino 8\blockblastpy\.venv3.12\Lib\site-packages\tensordict\base.py", line 633, in __torch_function__
return TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\programming\zoomino 8\blockblastpy\.venv3.12\Lib\site-packages\tensordict\_torch_func.py", line 666, in _stack
out._stack_onto_(list_of_tensordicts, dim)
File "C:\programming\zoomino 8\blockblastpy\.venv3.12\Lib\site-packages\tensordict\_td.py", line 2663, in _stack_onto_
new_dest = torch.stack(
^^^^^^^^^^^^
RuntimeError: torch.cat(): input types can't be cast to the desired output type Long
Here’s my code:
import torch
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.envs import (Compose, DoubleToFloat, StepCounter,
TransformedEnv)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import Actor
is_fork = multiprocessing.get_start_method() == "fork"
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)
num_cells = 256 # number of cells in each layer i.e. output dim.
frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000
base_env = GymEnv("CartPole-v1", device=device)
env = TransformedEnv(
base_env,
Compose(
DoubleToFloat(),
StepCounter(),
),
)
actor_net = nn.Sequential(
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(1, device=device), # Ensure correct output size
nn.Sigmoid()
)
policy_module = Actor(
module=actor_net,
in_keys=["observation"],
out_keys=["action"],
spec=env.action_spec
)
collector = SyncDataCollector(
env,
policy_module,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
split_trajs=False,
device=device,
)
for i, data in enumerate(collector):
print(i)
I’m very new to PyTorch and I’ve tried to understand the cause of the error, but couldn’t.