'CUDA out of memory' when using a GPU device for reinforcement learning in the Torch RPC tutorial

I followed this tutorial to implement reinforcement learning with RPC in PyTorch.

Currently, I use one trainer process and one observer process.

The trainer process creates the model, and the observer process calls the model's forward pass via RPC.

After adding a specified GPU device for the model to the code shown in the original tutorial, I encountered a “CUDA out of memory” error.

To simplify reproduction, I removed some of the original code from the tutorial.

from itertools import count

import gym
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

AGENT_NAME = "agent_{}"
OBSERVER_NAME = "obs_{}_for_{}"

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)


class Observer:

    def __init__(self):
        self.id = rpc.get_worker_info().id
        self.env = gym.make('CartPole-v1')
        # self.env.seed(args.seed)

    def run_episode(self, agent_rref):
        state, ep_reward = self.env.reset()[0], 0
        for _ in range(10000):
            # send the state to the agent to get an action
            action = agent_rref.rpc_sync().select_action(self.id, state)
            # apply the action to the environment, and get the reward
            state, reward, terminated, truncated, _ = self.env.step(action)
            # report the reward to the agent for training purpose
            agent_rref.rpc_sync().report_reward(self.id, reward)
            # the episode finishes after at most self.env._max_episode_steps steps
            if terminated or truncated:
                break
        torch.cuda.empty_cache()


class Agent:
    def __init__(self, rank, observer_size_pre_trainer, infos):
        self.ob_rrefs = []
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.saved_log_probs = {}
        self.device_id = rank % torch.cuda.device_count()
        self.policy = Policy().to(self.device_id)
        self.rank = rank
        for ob_rank in range(0, observer_size_pre_trainer):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank, rank))
            self.ob_rrefs.append(remote(ob_info, Observer))
            self.rewards[ob_info.id] = []
            self.saved_log_probs[ob_info.id] = []

    def select_action(self, ob_id, state):
        state: torch.Tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device_id)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        result = action.item()

        del action, m, state, probs
        return result

    def report_reward(self, ob_id, reward):
        self.rewards[ob_id].append(reward)

    def run_episode(self):
        futs = []
        for ob_rref in self.ob_rrefs:
            # make async RPC to kick off an episode on all observers
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    ob_rref.rpc_sync().run_episode,
                    args=(self.agent_rref,)
                )
            )

        # wait until all observers have finished this episode
        for fut in futs:
            fut.wait()

    def finish_episode(self):
        # joins probs and rewards from different observers into lists
        R, probs, rewards = 0, [], []
        for ob_id in self.rewards:
            probs.extend(self.saved_log_probs[ob_id])
            rewards.extend(self.rewards[ob_id])

        # clear saved probs and rewards
        for ob_id in self.rewards:
            self.rewards[ob_id] = []
            self.saved_log_probs[ob_id] = []
        del probs, rewards
        return 0


def get_observer_name(rank, trainer_size):
    observer_rank = (rank - trainer_size) // trainer_size
    trainer_rank = rank % trainer_size
    return OBSERVER_NAME.format(observer_rank, trainer_rank)


def run_worker(rank, trainer_size, observer_size_pre_trainer, infos):
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
        init_method='tcp://localhost:29500',
        num_worker_threads=1024,
    )
    world_size = observer_size_pre_trainer * trainer_size + trainer_size
    if rank < trainer_size:
        dist.init_process_group(
            rank=rank, world_size=trainer_size, init_method="tcp://localhost:29501"
        )
        name = AGENT_NAME.format(rank)
        print(f"{name} started")
        rpc.init_rpc(name, rank=rank, world_size=world_size,
                     rpc_backend_options=rpc_backend_options)

        agent = Agent(rank, observer_size_pre_trainer, infos)
        for i_episode in count(1):
            agent.run_episode()
            agent.finish_episode()
            print(f"episode : {i_episode}, mem_used: {torch.cuda.memory_allocated(agent.device_id) / 1024 / 1024:.2f}Mb")
    else:
        observer = get_observer_name(rank, trainer_size)
        print(f"{observer} started")
        # other ranks are the observer
        rpc.init_rpc(observer, rank=rank, world_size=world_size,
                     rpc_backend_options=rpc_backend_options)
        # observers passively waiting for instructions from the agent

    # block until all rpcs finish, and shutdown the RPC instance
    rpc.shutdown()


def main():
    mp.spawn(
        run_worker,
        args=(1, 1, {}),
        nprocs=2,
        join=True
    )


if __name__ == "__main__":
    torch.multiprocessing.set_start_method('spawn')
    main()

And I got this result:

/home/lu/PycharmProjects/tetris/venv/bin/python /home/lu/PycharmProjects/tetris/test_error.py 
obs_0_for_0 started
WARNING: Logging before InitGoogleLogging() is written to STDERR
I20230807 23:22:56.714407 262881 ProcessGroupNCCL.cpp:665] [Rank 0] ProcessGroupNCCL initialized with following options:
NCCL_ASYNC_ERROR_HANDLING: 0
NCCL_DESYNC_DEBUG: 0
NCCL_BLOCKING_WAIT: 0
TIMEOUT(ms): 1800000
USE_HIGH_PRIORITY_STREAM: 0
I20230807 23:22:56.714471 262980 ProcessGroupNCCL.cpp:842] [Rank 0] NCCL watchdog thread started!
agent_0 started
/home/lu/PycharmProjects/tetris/venv/lib/python3.11/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)
  if not isinstance(terminated, (bool, np.bool8)):
episode : 1, mem_used: 146.25Mb
episode : 2, mem_used: 251.88Mb
episode : 3, mem_used: 438.75Mb
episode : 4, mem_used: 682.50Mb
.....

episode : 44, mem_used: 4834.38Mb

At:
  /usr/lib/python3.11/site-packages/torch/distributed/rpc/internal.py(234): _handle_exception
')
Traceback (most recent call last):
  File "/usr/lib/python3.11/site-packages/torch/distributed/rpc/internal.py", line 207, in _run_function
    result = python_udf.func(*python_udf.args, **python_udf.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/site-packages/torch/distributed/rpc/rref_proxy.py", line 42, in _invoke_rpc
    return _rref_type_cont(rref_fut)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/site-packages/torch/distributed/rpc/rref_proxy.py", line 31, in _rref_type_cont
    return rpc_api(
           ^^^^^^^^
  File "/usr/lib/python3.11/site-packages/torch/distributed/rpc/api.py", line 82, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/site-packages/torch/distributed/rpc/api.py", line 809, in rpc_sync
    return fut.wait()
           ^^^^^^^^^^
RuntimeError: RuntimeError: On WorkerInfo(id=1, name=obs_0_for_0):
RuntimeError('OutOfMemoryError: On WorkerInfo(id=0, name=agent_0):
OutOfMemoryError('CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.75 GiB total capacity; 4.73 GiB already allocated; 10.88 MiB free; 5.82 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF')
Traceback (most recent call last):
  File "/usr/lib/python3.11/site-packages/torch/distributed/rpc/internal.py", line 207, in _run_function
    result = python_udf.func(*python_udf.args, **python_udf.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/site-packages/torch/distributed/rpc/rref_proxy.py", line 11, in _local_invoke
    return getattr(rref.local_value(), func_name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lu/PycharmProjects/tetris/test_error.py", line 71, in select_action
    probs = self.policy(state)
            ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lu/PycharmProjects/tetris/test_error.py", line 25, in forward
    x = self.affine1(x)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.75 GiB total capacity; 4.73 GiB already allocated; 10.88 MiB free; 5.82 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF



Process finished with exit code 1

And I use Manjaro (an Arch Linux-based OS), Python 3.11, and torch 2.0.1.

Based on this error, it seems as if another process might be using device memory on the same GPU. Could you check if this might be the case?
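One way to check this from inside the process is to compare what the driver reports as used on the whole GPU with what this process has reserved; a large gap suggests another process is holding the memory. A minimal sketch (torch.cuda.mem_get_info wraps cudaMemGetInfo and is available in torch 2.0):

import torch

device = 0
free_b, total_b = torch.cuda.mem_get_info(device)  # driver-level view of the whole GPU
reserved_b = torch.cuda.memory_reserved(device)    # held by this process' caching allocator
other_b = (total_b - free_b) - reserved_b          # rough estimate of memory held outside this process
print(f"total {total_b / 2**20:.0f} MiB, free {free_b / 2**20:.0f} MiB, "
      f"reserved here {reserved_b / 2**20:.0f} MiB, other processes ~{other_b / 2**20:.0f} MiB")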

The problem is that the memory usage of my model should be small (well under 100 MB), but in reality it rapidly grows to over 4 GB as the number of episodes increases.

        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)
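For reference, a rough back-of-the-envelope check of the parameter memory for these layers (my own estimate, not from the tutorial) shows that the parameters themselves are negligible, so the growth must come from other allocations that stay alive:

import torch.nn as nn

policy = nn.Sequential(nn.Linear(4, 128), nn.Dropout(p=0.6), nn.Linear(128, 2))
n_params = sum(p.numel() for p in policy.parameters())
# (4*128 + 128) + (128*2 + 2) = 898 float32 parameters, roughly 3.5 KiB
print(f"{n_params} parameters, ~{n_params * 4 / 1024:.1f} KiB as float32")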

And the memory usage printed per episode:

episode : 1, mem_used: 260.00Mb
episode : 2, mem_used: 430.63Mb
episode : 3, mem_used: 528.13Mb
episode : 4, mem_used: 674.38Mb
episode : 5, mem_used: 812.50Mb
episode : 6, mem_used: 910.00Mb
episode : 7, mem_used: 1145.63Mb
episode : 8, mem_used: 1235.00Mb
episode : 9, mem_used: 1324.38Mb
episode : 10, mem_used: 1543.75Mb
episode : 11, mem_used: 1673.75Mb
episode : 12, mem_used: 1868.75Mb
episode : 13, mem_used: 1990.63Mb
episode : 14, mem_used: 2169.38Mb
episode : 15, mem_used: 2437.50Mb
episode : 16, mem_used: 2600.00Mb
episode : 17, mem_used: 2681.25Mb
episode : 18, mem_used: 2811.25Mb
episode : 19, mem_used: 2835.63Mb
episode : 20, mem_used: 2851.88Mb
episode : 21, mem_used: 2851.88Mb
episode : 22, mem_used: 2933.13Mb
episode : 23, mem_used: 2933.13Mb
episode : 24, mem_used: 3022.50Mb
episode : 25, mem_used: 3046.88Mb
episode : 26, mem_used: 3152.50Mb
episode : 27, mem_used: 3233.75Mb
episode : 28, mem_used: 3355.63Mb
episode : 29, mem_used: 3469.38Mb
episode : 30, mem_used: 3843.13Mb
episode : 31, mem_used: 3932.50Mb
episode : 32, mem_used: 4030.00Mb
episode : 33, mem_used: 4127.50Mb
episode : 34, mem_used: 4208.75Mb
episode : 35, mem_used: 4346.88Mb
episode : 36, mem_used: 4387.50Mb
episode : 37, mem_used: 4387.50Mb
episode : 38, mem_used: 4403.75Mb
episode : 39, mem_used: 4460.63Mb
episode : 40, mem_used: 4460.63Mb
episode : 41, mem_used: 4517.50Mb
episode : 42, mem_used: 4566.25Mb
episode : 43, mem_used: 4606.88Mb
episode : 44, mem_used: 4647.50Mb
episode : 45, mem_used: 4647.50Mb
episode : 46, mem_used: 4655.63Mb
episode : 47, mem_used: 4696.25Mb
episode : 48, mem_used: 4728.75Mb

Since your memory usage increases with each episode, check whether you are storing tensors (which might still be attached to the computation graph) in a list, dict, or any other container. PyTorch won't be able to free these even if you explicitly call del tensor, since a reference to them is still stored. In-place accumulation with another tensor, such as loss += batch_loss, will have the same effect.
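For illustration, a minimal sketch of the pattern described above (the names are hypothetical): appending a tensor that is still attached to the autograd graph keeps the whole graph alive, while detaching it or converting it to a Python number does not.

import torch

saved = []
x = torch.randn(1, 4, requires_grad=True)
w = torch.randn(4, 2, requires_grad=True)

for _ in range(3):
    out = (x @ w).sum()
    saved.append(out)              # keeps the autograd graph of this iteration alive
    # saved.append(out.detach())   # stores only the value, the graph can be freed
    # saved.append(out.item())     # stores a plain Python float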

I removed the list appends, but the issue still persists. I also used torch.cuda.memory._record_memory_history to pinpoint where the memory is allocated.

import gc
import os.path
from itertools import count

import gym
import numpy as np
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

AGENT_NAME = "agent_{}"
OBSERVER_NAME = "obs_{}_for_{}"
torch.cuda.memory._record_memory_history(True)

def release_memory(rank):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    print(torch.cuda.memory_summary(rank))
    from pickle import dump
    snapshot = torch.cuda.memory._snapshot()
    if os.path.exists("snapshot.pickle"):
        if os.path.exists("snapshot.pickle.0"):
            os.remove("snapshot.pickle.0")
        os.rename("snapshot.pickle", "snapshot.pickle.0")
    dump(snapshot, open('snapshot.pickle', 'wb'))


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 2)
        # self.dropout = nn.Dropout(p=0.6)
        # self.affine2 = nn.Linear(128, 2)

    def forward(self, x):
        print(x.size())
        x1 = self.affine1(x)
        # x2 = self.dropout(x1)
        # x3 = F.relu(x2)
        # action_scores = self.affine2(x3)
        # result = F.softmax(action_scores, dim=1)
        # del action_scores, x1, x2, x3
        # return result
        result = F.softmax(x1, dim=1)
        del x1
        return result


class Observer:

    def __init__(self):
        self.id = rpc.get_worker_info().id

    def run_episode(self, agent_rref):
        state = np.random.random((4,))
        agent_rref.rpc_sync().select_action(self.id, state)


class Agent:
    def __init__(self, rank, observer_size_pre_trainer, infos):
        ob_info = rpc.get_worker_info(OBSERVER_NAME.format(0, rank))
        self.ob_rref = remote(ob_info, Observer)
        self.agent_rref = RRef(self)
        self.device_id = rank % torch.cuda.device_count()
        self.policy = Policy().to(self.device_id)
        self.rank = rank

    def select_action(self, ob_id, state):
        release_memory(self.rank)
        s: torch.Tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device_id)
        with torch.no_grad():
            probs = self.policy(s)
        # m = Categorical(probs)
        # action = m.sample()
        # log_prob = m.log_prob(action)
        # self.saved_log_probs[ob_id].append(log_prob.item())
        # result = action.item()
        #
        # del action, m, state, probs, log_prob
        del probs
        release_memory(self.device_id)
        return 1

    def run_episode(self):
        rpc_async(
            self.ob_rref.owner(),
            self.ob_rref.rpc_sync().run_episode,
            args=(self.agent_rref,)
        ).wait()


def run_worker(rank, world_size):
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
        init_method='tcp://localhost:29500',
        num_worker_threads=1024,
    )
    if rank == 0:
        name = AGENT_NAME.format(rank)
        print(f"{name} started")
        rpc.init_rpc(name, rank=rank, world_size=world_size,
                     rpc_backend_options=rpc_backend_options)

        agent = Agent(rank, 1, {})

        for i_episode in range(100):
            agent.run_episode()
            print(f"episode : {i_episode}, mem_used: {torch.cuda.memory_allocated(agent.device_id) / 1024 / 1024:.2f}Mb")

    else:
        observer = OBSERVER_NAME.format(0, 0)
        print(f"{observer} started")
        rpc.init_rpc(observer, rank=rank, world_size=world_size,
                     rpc_backend_options=rpc_backend_options)

    rpc.shutdown()


def main():
    mp.spawn(run_worker, args=(2,), nprocs=2, join=True)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method('spawn')
    main()

I found that the main cause of the increasing memory usage is the linear layer's forward pass inside the select_action function.
But I don't know why this memory is not released automatically.
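For comparison, a minimal single-process reduction of the same forward pass (my own check, not part of the tutorial); if torch.cuda.memory_allocated() stays flat here, the growth is specific to the RPC setup rather than to the linear layer itself:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = 0
policy = nn.Linear(4, 2).to(device)

for i in range(100):
    state = torch.from_numpy(np.random.random((4,))).float().unsqueeze(0).to(device)
    with torch.no_grad():
        probs = F.softmax(policy(state), dim=1)
    del state, probs
    print(f"step {i}, mem_used: {torch.cuda.memory_allocated(device) / 1024 / 1024:.2f}Mb")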