Memory increasing with IterableDataset

Hi!

I’m reproducing the results of drqv2, but I ran into what looks like a memory leak during training. After ruling out several factors, I found that the most likely cause is the IterableDataset.

Here are the details of the ReplayBuffer:

import random
import traceback

import numpy as np
import torch
from torch.utils.data import IterableDataset

# load_episode and episode_len are small helpers defined alongside this class
# in the drqv2 repo: they load a saved .npz episode into a dict of arrays and
# return its length, respectively.


class ReplayBuffer(IterableDataset):
    def __init__(self, replay_dir, max_size, num_workers, nstep, discount,
                 fetch_every, save_snapshot):
        self._replay_dir = replay_dir
        self._size = 0
        self._max_size = max_size
        self._num_workers = max(1, num_workers)
        self._episode_fns = []
        self._episodes = dict()
        self._nstep = nstep
        self._discount = discount
        self._fetch_every = fetch_every
        self._samples_since_last_fetch = fetch_every
        self._save_snapshot = save_snapshot

    def _sample_episode(self):
        eps_fn = random.choice(self._episode_fns)
        return self._episodes[eps_fn]

    def _store_episode(self, eps_fn):
        try:
            episode = load_episode(eps_fn)
        except:
            return False
        eps_len = episode_len(episode)
        # evict the oldest episodes (and their files) until the new one fits
        # under the _max_size budget
        while eps_len + self._size > self._max_size:
            early_eps_fn = self._episode_fns.pop(0)
            early_eps = self._episodes.pop(early_eps_fn)
            self._size -= episode_len(early_eps)
            early_eps_fn.unlink(missing_ok=True)
        self._episode_fns.append(eps_fn)
        self._episode_fns.sort()
        self._episodes[eps_fn] = episode
        self._size += eps_len

        if not self._save_snapshot:
            eps_fn.unlink(missing_ok=True)
        return True

    def _try_fetch(self):
        if self._samples_since_last_fetch < self._fetch_every:
            return
        self._samples_since_last_fetch = 0
        try:
            worker_id = torch.utils.data.get_worker_info().id
        except:
            worker_id = 0
        eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True)
        fetched_size = 0
        for eps_fn in eps_fns:
            eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
            # shard episodes across DataLoader workers by episode index
            if eps_idx % self._num_workers != worker_id:
                continue
            if eps_fn in self._episodes.keys():
                break
            if fetched_size + eps_len > self._max_size:
                break
            fetched_size += eps_len
            if not self._store_episode(eps_fn):
                break

    def _sample(self):
        try:
            self._try_fetch()
        except:
            traceback.print_exc()
        self._samples_since_last_fetch += 1
        episode = self._sample_episode()
        # add +1 for the first dummy transition
        idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
        obs = episode['observation'][idx - 1]
        action = episode['action'][idx]
        next_obs = episode['observation'][idx + self._nstep - 1]
        reward = np.zeros_like(episode['reward'][idx])
        discount = np.ones_like(episode['discount'][idx])
        for i in range(self._nstep):
            step_reward = episode['reward'][idx + i]
            reward += discount * step_reward
            discount *= episode['discount'][idx + i] * self._discount
        return (obs, action, reward, discount, next_obs)

    def __iter__(self):
        while True:

The logic of this IterableDataset is to load and store data dynamically rather than loading everything at once at initialization. I found a detailed discussion here and tried replacing self._episode_fns = [] with an np.array. I also used deepcopy on the values returned at the bottom of the _sample function. But none of these attempts helped: memory usage still reaches 30+ GB after only 200,000 training steps.
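
For reference, the list-to-array replacement I tried looks roughly like the sketch below. The helper functions are hypothetical and only mirror the two places where self._episode_fns is used; the episodes themselves still live in the ordinary self._episodes dict:

import numpy as np
from pathlib import Path

# Hypothetical sketch: keep the episode filenames in a fixed-width numpy
# string array instead of a Python list of Path objects.
def add_episode_fn(episode_fns: np.ndarray, eps_fn: Path) -> np.ndarray:
    # mirrors self._episode_fns.append(eps_fn) followed by .sort()
    return np.sort(np.append(episode_fns, str(eps_fn)))

def sample_episode_fn(episode_fns: np.ndarray) -> Path:
    # mirrors random.choice(self._episode_fns) in _sample_episode
    return Path(str(np.random.choice(episode_fns)))

episode_fns = np.array([], dtype='U256')  # was: self._episode_fns = []
episode_fns = add_episode_fn(episode_fns, Path('some_episode.npz'))
print(sample_episode_fn(episode_fns))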

Is there any better way to optimize this dataset?

I’d be glad to receive any help. Thanks in advance.

I think this post needs more detail. Can you show where you got the result of 30+GB from? Do you get an error which you can show?
Why are you replacing the list with a NumPy array? Maybe try using torch tensors instead? Combining NumPy and PyTorch can sometimes be confusing and lead to errors, and converting arrays and tensors back and forth adds unwanted time overhead.
My advice is to debug the dataset in the middle of an iteration. While debugging, look at which variables are in use and how much memory their values take up.
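
For example, something along these lines (a rough sketch; it assumes psutil is installed and that the buffer stores episodes as dicts of numpy arrays, as in your snippet):

import os
import psutil  # assumed to be available (pip install psutil)

def log_buffer_memory(buffer, step):
    # resident memory of this (worker) process
    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1e9
    # raw size of the numpy arrays currently held by the buffer
    stored_gb = sum(arr.nbytes
                    for episode in buffer._episodes.values()
                    for arr in episode.values()) / 1e9
    print(f"pid={os.getpid()} step={step} rss={rss_gb:.2f} GB "
          f"episodes={len(buffer._episodes)} arrays={stored_gb:.2f} GB")

Calling something like this every few thousand samples from inside _sample should show whether the growth comes from the stored episodes themselves or from somewhere else.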
Furthermore, I noticed that you implemented the __iter__ method with a while True: loop. Is that the end of the code, or did something not get copied into your post? The method should return an iterator object, which it currently doesn’t.

Thanks for your time and comment!

I think this post needs more detail. Can you show where you got the result of 30+GB from? Do you get an error which you can show?

I monitor RAM with free -h. When it is almost fully occupied, the following error is displayed:

Exception in thread Thread-2:
Traceback (most recent call last):
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/threading.py", line 980, in _bootstrap_inner
Error executing job with overrides: []
    self.run()
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 54, in _pin_memory_loop
    do_one_step()
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 31, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 355, in rebuild_storage_fd
    fd = df.detach()
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/multiprocessing/resource_sharer.py", line 86, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/multiprocessing/connection.py", line 508, in Client
    answer_challenge(c, authkey)
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/multiprocessing/connection.py", line 757, in answer_challenge
    response = connection.recv_bytes(256)        # reject large message
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer
Traceback (most recent call last):
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/queue.py", line 180, in get
    self.not_empty.wait(remaining)
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/threading.py", line 316, in wait
    gotit = waiter.acquire(True, timeout)
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 17761) is killed by signal: Killed. 

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xzc/kuka_rl_ws/rl_vigen/train_prop.py", line 285, in main
    workspace.train()
  File "/home/xzc/kuka_rl_ws/rl_vigen/train_prop.py", line 244, in train
    metrics = self.agent.update(self.replay_iter, self.global_step)
  File "/home/xzc/kuka_rl_ws/rl_vigen/algos/drqv2_prop.py", line 293, in update
    metrics["actor_loss"] = actor_loss.item()
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
    idx, data = self._get_data()
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1284, in _get_data
    success, data = self._try_get_data()
  File "/home/xzc/miniforge3/envs/sb-zoo/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1145, in _try_get_data
    raise RuntimeError(f'DataLoader worker (pid(s) {pids_str}) exited unexpectedly') from e
RuntimeError: DataLoader worker (pid(s) 17761) exited unexpectedly

Why are you replacing the list with a NumPy array? Maybe try using torch tensors instead? Combining NumPy and PyTorch can sometimes be confusing and lead to errors, and converting arrays and tensors back and forth adds unwanted time overhead.

I found the suggestion here; some of the answers mention that replacing a Python list with a numpy array can avoid the memory increase caused by reference counting, but it does not seem to work for me.
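
As I understand it, the effect those answers describe can be reproduced with a small standalone script like the one below (unrelated to drqv2; it just illustrates copy-on-write under fork and assumes psutil is installed). USS is measured instead of RSS because RSS also counts pages that are still shared with the parent:

import multiprocessing as mp
import os
import psutil

def reader(data):
    proc = psutil.Process(os.getpid())
    uss_before = proc.memory_full_info().uss  # memory unique to this process
    total = 0
    for x in data:   # only reads the objects...
        total += x   # ...but the refcount updates copy the shared pages
    uss_after = proc.memory_full_info().uss
    print(f"child USS grew by {(uss_after - uss_before) / 1e6:.0f} MB")

if __name__ == "__main__":
    data = list(range(10_000_000))   # millions of small Python objects
    ctx = mp.get_context("fork")     # fork start method, as used on Linux
    p = ctx.Process(target=reader, args=(data,))
    p.start()
    p.join()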

Furthermore, I noticed that you implemented the __iter__ method with a while True: loop. Is that the end of the code, or did something not get copied into your post? The method should return an iterator object, which it currently doesn’t.

I have attached some other necessary code:

def make_replay_loader(replay_dir, max_size, batch_size, num_workers,
                       save_snapshot, nstep, discount):
    max_size_per_worker = max_size // max(1, num_workers)

    iterable = ReplayBuffer(replay_dir,
                            max_size_per_worker,
                            num_workers,
                            nstep,
                            discount,
                            fetch_every=1000,
                            save_snapshot=save_snapshot)

    loader = torch.utils.data.DataLoader(iterable,
                                         batch_size=batch_size,
                                         num_workers=num_workers,
                                         pin_memory=True,
                                         worker_init_fn=_worker_init_fn)
    return loader
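
_worker_init_fn is not included in the snippet above; a typical implementation for this kind of per-worker seeding would look something like the sketch below (not necessarily the exact drqv2 version):

import random

import numpy as np

def _worker_init_fn(worker_id):
    # give each DataLoader worker its own RNG seed so the workers do not
    # fetch and sample identical episodes
    seed = np.random.get_state()[1][0] + worker_id
    np.random.seed(seed)
    random.seed(seed)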

Thanks for the additional information. The error messages probably don’t point to the problem themselves, but they show at what point you have used too much memory.
I still see a problem in your implementation of the __iter__ method. You override __iter__ from IterableDataset, which should return an iterator on which next can be called to get samples from your dataset. By overriding it like that and passing your ReplayBuffer to the PyTorch DataLoader, the DataLoader will call __iter__ on the ReplayBuffer, which results in an infinite while loop, probably crashing your DataLoader workers.
^ This is my theory of what might be happening, but correct me if I’m wrong. Hopefully your method works again once you remove that override of the __iter__ method.

Thanks for your patience. Honestly, the code snippet I provided is not mine; it is the official implementation of drqv2. As you mentioned, I’m also confused by the __iter__ function, but the while loop does not seem to cause an infinite loop: the dataloader still returns data with the given batch size.
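
For reference, I assume the truncated __iter__ simply yields each sample from inside the loop, something like the sketch below. A generator like this is a valid __iter__ for an IterableDataset, which would explain why the DataLoader keeps returning batches instead of hanging:

    def __iter__(self):
        # a generator: each next() call resumes here, draws one transition,
        # and suspends again, so the while True loop never blocks the worker
        while True:
            yield self._sample()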