Training converges on CPU but never on GPU

Hi,
For the script below, my training always converges on CPU and solves the environment. However, if I switch DEVICE to "cuda", it never converges.

My Python environment: torch 2.5.1, torchrl 0.6.0, tensordict 0.6.2.

What I tried so far:

  • I also tried running the code on the nightly version of torchrl, and the behaviour stayed the same.
  • I verified that my GPU is working correctly; I was also able to run the torchrl PPO tutorial successfully on GPU.
  • I ran the script on my workplace's GPU server as well as on a separate Windows laptop using torchrl-nightly, and the behaviour still persists.

This leads me to conclude that I have some hidden bug in my script or did not use torchrl correctly. I would be really thankful for any pointers on what to look at.

Note: I also posted this on the torchrl GitHub under Discussions, but I hope it might find more attention here.

import logging
import time

import torch.nn as nn
import torch.optim as optim

from tensordict.nn import TensorDictSequential as TDSeq

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv, TransformedEnv, StepCounter, RewardSum, Compose
from torchrl.modules import EGreedyModule, QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate


logging.basicConfig(level=logging.INFO)
my_logger = logging.getLogger(__name__)


ENV_NAME = "CartPole-v1"
DEVICE = "cuda"

INIT_RND_STEPS = 5_120
FRAMES_PER_BATCH = 128
BUFFER_SIZE = 100_000

GAMMA = 0.98
OPTIM_STEPS = 10
BATCH_SIZE = 128

SOFTU_EPS = 0.99
LR = 0.02


class Net(nn.Module):
    def __init__(self, obs_size: int, n_actions: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x):
        orig_shape_unbatched = len(x.shape) == 1
        if orig_shape_unbatched:
            x = x.unsqueeze(0)

        out = self.net(x)

        if orig_shape_unbatched:
            out = out.squeeze(0)
        return out


def make_env(env_name: str):
    return TransformedEnv(
        GymEnv(env_name),
        Compose(
            StepCounter(),
            RewardSum()
        )
    )


if __name__ == "__main__":
    env = make_env(ENV_NAME)
    n_obs = env.observation_spec["observation"].shape[-1]
    n_act = env.action_spec.shape[-1]

    net = Net(n_obs, n_act).to(device=DEVICE)
    agent = QValueActor(net, spec=env.action_spec)
    policy_explore = EGreedyModule(env.action_spec)
    agent_explore = TDSeq(agent, policy_explore)

    collector = SyncDataCollector(
        env,
        agent_explore,
        frames_per_batch=FRAMES_PER_BATCH,
        init_random_frames=INIT_RND_STEPS,
        device=DEVICE,
    )
    exp_buffer = ReplayBuffer(storage=LazyTensorStorage(BUFFER_SIZE))

    loss = DQNLoss(value_network=agent, action_space=env.action_spec, delay_value=True)
    loss.make_value_estimator(gamma=GAMMA)
    target_updater = SoftUpdate(loss, eps=SOFTU_EPS)
    optimizer = optim.Adam(loss.parameters(), lr=LR)

    total_count = 0
    total_episodes = 0
    t0 = time.time()
    for i, data in enumerate(collector):
        exp_buffer.extend(data)
        max_length = exp_buffer["next", "step_count"].max()
        max_reward = exp_buffer["next", "episode_reward"].max()
        if len(exp_buffer) > INIT_RND_STEPS:
            for _ in range(OPTIM_STEPS):
                optimizer.zero_grad()
                sample = exp_buffer.sample(batch_size=BATCH_SIZE)
                sample = sample.to(DEVICE)

                loss_vals = loss(sample)
                loss_vals["loss"].backward()
                optimizer.step()

                # anneal the exploration epsilon by the number of frames collected
                agent_explore[1].step(data.numel())
                target_updater.step()

                total_count += data.numel()
                total_episodes += data["next", "done"].sum()

            if i % 10 == 0:
                my_logger.info(f"Step: {i}, max. count / epi reward: {max_length} / {max_reward}.")

        if max_length > 200:
            t1 = time.time()
            my_logger.info(f"SOLVED in {t1 - t0}s!! MaxLen: {max_length}!")
            my_logger.info(f"With {max_reward} Reward!")
            my_logger.info(f"In {total_episodes} Episodes!")
            break

There seems to be a bug with the data transfers in the collector; I'll give it a shot next week and keep you posted!

Thank you, really appreciate it!

Hello!
I solved it in this PR.
The issue was that your policy wasn't completely on CUDA: part of it (namely the exploration module) was still on CPU. When you ask the collector to run the policy, it sends everything to CUDA, and there was an issue that made the collector lose track of the origin tensors (this is what the PR fixes).
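To see which parts of the policy ended up where, a quick check like this works (just a sketch using the names from your script; the exact buffer names may differ):

# Print the device of every parameter and buffer in the policy.
# With your script, the Q-network parameters report cuda:0 while the
# EGreedyModule buffers (the epsilon schedule) still report cpu.
for name, p in agent_explore.named_parameters():
    print("param ", name, p.device)
for name, b in agent_explore.named_buffers():
    print("buffer", name, b.device)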
You need to correct your script a bit, though:

  • You could do agent_explore = agent_explore.to(device), in which case you don't need the PR (see the sketch after the snippet below).
  • If you don't do that, use the PR (nightly build) and add a call to collector.update_policy_weights_() just after your model update:
[...]
                total_count += data.numel()
                total_episodes += data["next", "done"].sum()

            if i % 10 == 0:
                my_logger.info(f"Step: {i}, max. count / epi reward: {max_length} / {max_reward}.")
        collector.update_policy_weights_()

That will copy the CPU buffers onto the GPU.
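For completeness, the first option is just a one-line change where you build the policy (a sketch using the names from your script):

# Option 1: move the whole policy (Q-value actor + exploration module) to the
# collector device so that no part of it is left behind on CPU.
agent = QValueActor(net, spec=env.action_spec)
policy_explore = EGreedyModule(env.action_spec)
agent_explore = TDSeq(agent, policy_explore).to(DEVICE)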

I also spotted a bug when you have partial devices (i.e. one for the policy and one for the env), which I fixed in [BugFix] Fix device transfer for collectors with init_random_frames mixed devices by vmoens · Pull Request #2704 · pytorch/rl · GitHub

With these changes, I solve the whole thing in a similar number of iterations in every case.

LMK if that works!


Thank you for looking into this issue!

  • I can confirm that agent_explore = agent_explore.to(device) fixes the problem.

Edit: I tested on the latest nightly, not on the latest GitHub changes; that's why I saw the behaviour below. With the latest changes, the proposed fix works!

  • Adding collector.update_policy_weights_() fixes the issue only if the device in the collector is set to "cpu", as below:
collector = SyncDataCollector(
    env,
    agent_explore,
    frames_per_batch=FRAMES_PER_BATCH,
    init_random_frames=INIT_RND_STEPS,
    device="cpu",
)

If device="cuda", training still doesn't converge.

For me this is fine, and I would have been happy with just pushing the policy onto the GPU as the fix, but maybe one would want to enforce that if the device in the collector is set, then the policy has to be on the same device (a quick check along those lines is sketched below).
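Something like this before building the collector would have caught my mistake (just a sketch, assuming the policy should live entirely on a single device):

import torch

# Check that every parameter and buffer of the policy already sits on the
# device type the collector is asked to use.
expected = torch.device(DEVICE).type  # "cuda" or "cpu"
for name, t in list(agent_explore.named_parameters()) + list(agent_explore.named_buffers()):
    assert t.device.type == expected, f"{name} is on {t.device}, expected {expected}"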

Again, thank you for looking into this! :slight_smile:

It will only work if you pull the latest changes from GitHub!

My bad, you are right! I tested on the nightly version of torchrl but not with the latest changes from GitHub.

Can confirm, it works!
