Hi!
I’m reproducing the results of drqv2, but I found what looks like a memory leak during training. After ruling out other factors, the most likely cause of this behavior is the `IterableDataset`.
Here is the `ReplayBuffer` in detail (including the two small helpers it relies on):
```python
import random
import traceback

import numpy as np
import torch
from torch.utils.data import IterableDataset


# helpers from drqv2's replay_buffer.py
def episode_len(episode):
    # subtract 1 for the dummy first transition
    return next(iter(episode.values())).shape[0] - 1


def load_episode(fn):
    with fn.open('rb') as f:
        episode = np.load(f)
        episode = {k: episode[k] for k in episode.keys()}
    return episode


class ReplayBuffer(IterableDataset):
    def __init__(self, replay_dir, max_size, num_workers, nstep, discount,
                 fetch_every, save_snapshot):
        self._replay_dir = replay_dir
        self._size = 0
        self._max_size = max_size
        self._num_workers = max(1, num_workers)
        self._episode_fns = []
        self._episodes = dict()
        self._nstep = nstep
        self._discount = discount
        self._fetch_every = fetch_every
        self._samples_since_last_fetch = fetch_every
        self._save_snapshot = save_snapshot

    def _sample_episode(self):
        eps_fn = random.choice(self._episode_fns)
        return self._episodes[eps_fn]

    def _store_episode(self, eps_fn):
        try:
            episode = load_episode(eps_fn)
        except Exception:
            return False
        eps_len = episode_len(episode)
        # evict the oldest episodes until the new one fits in the budget
        while eps_len + self._size > self._max_size:
            early_eps_fn = self._episode_fns.pop(0)
            early_eps = self._episodes.pop(early_eps_fn)
            self._size -= episode_len(early_eps)
            early_eps_fn.unlink(missing_ok=True)
        self._episode_fns.append(eps_fn)
        self._episode_fns.sort()
        self._episodes[eps_fn] = episode
        self._size += eps_len
        if not self._save_snapshot:
            eps_fn.unlink(missing_ok=True)
        return True

    def _try_fetch(self):
        if self._samples_since_last_fetch < self._fetch_every:
            return
        self._samples_since_last_fetch = 0
        try:
            worker_id = torch.utils.data.get_worker_info().id
        except Exception:
            worker_id = 0
        eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True)
        fetched_size = 0
        for eps_fn in eps_fns:
            eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
            # episodes are partitioned across workers by episode index
            if eps_idx % self._num_workers != worker_id:
                continue
            if eps_fn in self._episodes.keys():
                break
            if fetched_size + eps_len > self._max_size:
                break
            fetched_size += eps_len
            if not self._store_episode(eps_fn):
                break

    def _sample(self):
        try:
            self._try_fetch()
        except Exception:
            traceback.print_exc()
        self._samples_since_last_fetch += 1
        episode = self._sample_episode()
        # add +1 for the first dummy transition
        idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
        obs = episode['observation'][idx - 1]
        action = episode['action'][idx]
        next_obs = episode['observation'][idx + self._nstep - 1]
        reward = np.zeros_like(episode['reward'][idx])
        discount = np.ones_like(episode['discount'][idx])
        # accumulate the n-step return and the matching discount
        for i in range(self._nstep):
            step_reward = episode['reward'][idx + i]
            reward += discount * step_reward
            discount *= episode['discount'][idx + i] * self._discount
        return (obs, action, reward, discount, next_obs)

    def __iter__(self):
        while True:
            yield self._sample()
```
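For context, the buffer is consumed through a `DataLoader`, roughly like this — a sketch modeled on drqv2's `make_replay_loader`; treat the exact signature and the `fetch_every=1000` value as assumptions. Each worker process holds its own copy of the dataset, which is why the budget is divided by `num_workers`:

```python
def _worker_init_fn(worker_id):
    # give each worker a distinct seed so they don't sample identically
    seed = np.random.get_state()[1][0] + worker_id
    np.random.seed(seed)
    random.seed(seed)


def make_replay_loader(replay_dir, max_size, batch_size, num_workers,
                       save_snapshot, nstep, discount):
    # every worker keeps its own episode cache, so split the memory budget
    max_size_per_worker = max_size // max(1, num_workers)
    dataset = ReplayBuffer(replay_dir, max_size_per_worker, num_workers,
                           nstep, discount, fetch_every=1000,
                           save_snapshot=save_snapshot)
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       num_workers=num_workers,
                                       pin_memory=True,
                                       worker_init_fn=_worker_init_fn)
```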
The logic of this `IterableDataset` is to load and store data dynamically rather than loading everything at once at initialization. I found a detailed discussion here and tried replacing the `self._episode_fns = []` list with an `np.array`. I also tried returning a `deepcopy` of the values at the bottom of the `_sample` function. But none of these attempts worked well: memory usage still reaches 30 GB+ after only 200,000 training steps.
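For clarity, the two attempts looked roughly like this — a standalone sketch, not the exact patch; `replay_dir` and `sample_with_copies` are hypothetical stand-ins:

```python
import copy
from pathlib import Path

import numpy as np

replay_dir = Path('buffer')  # hypothetical directory of *.npz episodes

# attempt 1: keep episode file names in a fixed-width numpy string array
# instead of a Python list, following the linked discussion on refcount
# updates triggering copy-on-write in forked DataLoader workers
episode_fns = np.array(sorted(str(p) for p in replay_dir.glob('*.npz')))


def sample_with_copies(obs, action, reward, discount, next_obs):
    # attempt 2: deep-copy the tuple at the end of _sample so the batch
    # holds no views into the cached episode arrays
    return copy.deepcopy((obs, action, reward, discount, next_obs))
```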
Is there any better way to optimize this dataset?
I'd be glad to receive any help. Thanks in advance.