How to make the replay buffer more efficient?

  1. Use random indices to pick samples instead of sampling directly from the list (e.g. with `random.sample`), which copies elements around.
  2. Save on tensor creation: build tensors per batch rather than per element, and if your data lives in NumPy arrays, use `torch.from_numpy`, which shares memory instead of copying:
    def sample(self, batch_size):
        # One randint call produces all indices at once.
        batch_idxs = np.random.randint(len(self), size=batch_size)
        # Transpose the sampled transitions into per-field batches.
        batches = list(zip(*(self.mem[i] for i in batch_idxs)))
        # Stack each field once, then convert without a copy.
        return [torch.from_numpy(np.stack(batch)) for batch in batches]
  3. To eliminate the tensor-creation cost entirely, store your data as tensors from the start and index into them directly, skipping `torch.from_numpy` altogether.
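To make the `sample` snippet above self-contained, here is a minimal sketch of a list-backed buffer illustrating points 1 and 2. The class name, `push` method, and transition layout (state, action, reward, next state) are assumptions for illustration; the original only shows `sample`:

```python
import numpy as np
import torch

class ReplayBuffer:
    """Sketch: transitions stored as tuples of NumPy arrays in a list."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.mem = []   # list of (state, action, reward, next_state) tuples
        self.pos = 0

    def push(self, *transition):
        # Circular buffer: overwrite the oldest entry once full.
        if len(self.mem) < self.capacity:
            self.mem.append(transition)
        else:
            self.mem[self.pos] = transition
        self.pos = (self.pos + 1) % self.capacity

    def __len__(self):
        return len(self.mem)

    def sample(self, batch_size):
        # Point 1: pick random indices instead of sampling the list itself.
        batch_idxs = np.random.randint(len(self), size=batch_size)
        # Transpose the sampled transitions into per-field batches.
        batches = list(zip(*(self.mem[i] for i in batch_idxs)))
        # Point 2: stack each field once, convert with zero-copy from_numpy.
        return [torch.from_numpy(np.stack(b)) for b in batches]
```

Sampling returns one tensor per field, each of shape `(batch_size, *field_shape)`, so the result can be fed straight into a loss computation.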
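Point 3 can be sketched as a buffer with preallocated tensor storage; sampling is then a single advanced-indexing operation with no conversion at all. The field names and shapes here are assumptions, not from the original:

```python
import torch

class TensorReplayBuffer:
    """Sketch: fields preallocated as tensors, indexed directly on sample."""

    def __init__(self, capacity, state_dim):
        self.capacity = capacity
        self.states = torch.zeros(capacity, state_dim)
        self.actions = torch.zeros(capacity, 1, dtype=torch.long)
        self.rewards = torch.zeros(capacity, 1)
        self.next_states = torch.zeros(capacity, state_dim)
        self.pos = 0
        self.size = 0

    def push(self, state, action, reward, next_state):
        # Write each field into the next slot of the circular buffer.
        i = self.pos
        self.states[i] = state
        self.actions[i] = action
        self.rewards[i] = reward
        self.next_states[i] = next_state
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        # Advanced indexing on tensors: no from_numpy, no stacking.
        idxs = torch.randint(self.size, (batch_size,))
        return (self.states[idxs], self.actions[idxs],
                self.rewards[idxs], self.next_states[idxs])
```

The trade-off is that shapes must be known up front and memory is allocated eagerly, but sampling becomes four indexing operations regardless of batch size.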