How to make the replay buffer more efficient?

I have implemented a DQN model in PyTorch, but training runs slowly, and the main reason is that sampling from the replay buffer takes up almost all of the time.

Here is my implementation of the replay buffer:

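import random
from collections import deque

import torch as T


class DQNBuffer:
    def __init__(self, maxlen=100000, device=None):
        # deque drops the oldest transition once maxlen is reached
        self.mem = deque(maxlen=maxlen)
        self.maxlen = maxlen
        self.device = device

    def store(self, s, a, r, s_, a_, d):
        self.mem.append([s, a, r, s_, a_, d])

    def sample(self, batch_size):
        # draw batch_size transitions without replacement
        bat = random.sample(self.mem, batch_size)
        # transpose: one tuple per field (s, a, r, s_, a_, d)
        batch = list(zip(*bat))
        data = []
        for i in range(len(batch)):
            # build one tensor per field from the tuple of batch_size items
            data.append(T.as_tensor(batch[i], dtype=T.float32, device=self.device))
        return data

    def __len__(self):
        return len(self.mem)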

I measured the time spent on each part of training one batch of samples. Here is the code:

import timeit

import torch.nn.functional as F


def train_batch(agent: DQNAgent, batch_size, gamma):
    agent.gstep += 1

    # time the replay buffer sampling
    A = timeit.default_timer()
    s, a, r, s_, a_, d = agent.buf.sample(batch_size)
    print(timeit.default_timer() - A)

    # time the target and loss computation
    B = timeit.default_timer()
    qhat = agent.tnet(s_, a_).reshape(-1, ).detach()
    target = r + (1.0-d)*gamma*qhat
    loss = F.mse_loss(agent.qnet(s, a).reshape(-1, ), target)
    print(timeit.default_timer() - B)

    return float(loss)
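
A side note on the measurement itself: if the networks run on a GPU, CUDA kernels are launched asynchronously, so a plain timer around the forward pass can under-report it, and the wait may show up in a later step instead. A minimal sketch of a timing helper that optionally synchronizes before reading the clock follows. The names `timed`, `timings`, and `mean_time` are hypothetical and not part of the code above.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

# Hypothetical helper, not part of the original code: collects wall-clock
# durations per label so they can be averaged over many batches.
timings = defaultdict(list)


@contextmanager
def timed(label, sync=None):
    """Time the enclosed block and record it under `label`.

    `sync` is an optional callable that waits for pending device work,
    for example torch.cuda.synchronize when running on a GPU.
    """
    if sync is not None:
        sync()  # finish earlier queued work so it is not billed to this block
    start = time.perf_counter()
    try:
        yield
    finally:
        if sync is not None:
            sync()  # wait for work launched inside the block
        timings[label].append(time.perf_counter() - start)


def mean_time(label):
    """Average duration in seconds recorded under `label`."""
    values = timings[label]
    return sum(values) / len(values)
```

It could be used as `with timed("sample", sync=torch.cuda.synchronize): ...` around the sampling call and again around the loss computation, then compared with `mean_time("sample")` after a few hundred batches. On CPU, `sync` can simply be left out.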

Some of the output (4 batches):


agent.buf.sample(batch_size) takes over 90% of the time, which is too slow.
How can I optimize the sample operation?

  1. Use indices to pick samples instead of sampling directly from the list.
  2. Save on tensor creation: convert a whole batch at once, and if your data is NumPy, use torch.from_numpy:
    def sample(self, batch_size):
        # assumes self.mem supports array indexing (e.g. a NumPy array, not a deque)
        batch_idxs = np.random.randint(len(self), size=batch_size)
        batches = list(zip(*self.mem[batch_idxs]))
        return [torch.from_numpy(batch) for batch in batches]
  3. If you want to save on the tensor creation cost, store your data directly as tensors, then sample directly without torch.from_numpy(batch).
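
To illustrate the last suggestion, here is a minimal sketch of a buffer that stores every field in a preallocated tensor and samples by index. The class name `TensorReplayBuffer` and the `state_dim` / `action_dim` parameters are assumptions for the example, and states and actions are assumed to be flat float vectors, so it would need adapting to other shapes.

```python
import torch


class TensorReplayBuffer:
    """Ring buffer that keeps each field in one preallocated tensor.

    Hypothetical sketch: the class name, state_dim and action_dim are
    assumptions, not part of the original DQNBuffer.
    """

    def __init__(self, maxlen, state_dim, action_dim, device=None):
        self.maxlen = maxlen
        self.size = 0  # number of valid transitions stored so far
        self.pos = 0   # next slot to overwrite
        kw = dict(dtype=torch.float32, device=device)
        self.s = torch.zeros((maxlen, state_dim), **kw)
        self.a = torch.zeros((maxlen, action_dim), **kw)
        self.r = torch.zeros(maxlen, **kw)
        self.s_ = torch.zeros((maxlen, state_dim), **kw)
        self.a_ = torch.zeros((maxlen, action_dim), **kw)
        self.d = torch.zeros(maxlen, **kw)

    def store(self, s, a, r, s_, a_, d):
        i = self.pos
        dev = self.s.device
        # Convert once at insertion time, so sampling does no conversion.
        self.s[i] = torch.as_tensor(s, dtype=torch.float32, device=dev)
        self.a[i] = torch.as_tensor(a, dtype=torch.float32, device=dev)
        self.r[i] = float(r)
        self.s_[i] = torch.as_tensor(s_, dtype=torch.float32, device=dev)
        self.a_[i] = torch.as_tensor(a_, dtype=torch.float32, device=dev)
        self.d[i] = float(d)
        self.pos = (i + 1) % self.maxlen          # wrap around when full
        self.size = min(self.size + 1, self.maxlen)

    def sample(self, batch_size):
        # Uniform indices over the filled part, drawn with replacement.
        idx = torch.randint(0, self.size, (batch_size,), device=self.r.device)
        return [self.s[idx], self.a[idx], self.r[idx],
                self.s_[idx], self.a_[idx], self.d[idx]]

    def __len__(self):
        return self.size
```

Two differences from the original buffer are worth keeping in mind: indices are drawn with replacement (random.sample draws without), and actions come back with shape (batch_size, action_dim), so a reshape may be needed depending on what the networks expect.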

The sample() above produces tuples (from zip). Doesn't that mean we still have to convert them into tensors? I tried it, but when I pass a batch created this way, torch throws an error saying that a tuple object has no attribute “dim”.
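
Yes, the tuples still need converting. One way to do it, as a rough sketch: stack each tuple into a single NumPy array first, then turn that array into a tensor. The helper name `to_tensors` is hypothetical, and it assumes every field is a number or a NumPy array with the same shape across transitions.

```python
import numpy as np
import torch


def to_tensors(transitions, device=None):
    """Turn a list of (s, a, r, s_, a_, d) transitions into six tensors.

    Hypothetical helper; assumes each field is a number or a NumPy array
    whose shape is identical across all transitions.
    """
    out = []
    for column in zip(*transitions):  # one tuple per field
        # Stack into one contiguous array first; creating a tensor
        # directly from a tuple of arrays is much slower.
        arr = np.asarray(column, dtype=np.float32)
        out.append(torch.as_tensor(arr, device=device))
    return out
```

With this, `s, a, r, s_, a_, d = to_tensors(sampled_transitions, device)` gives tensors that have a `dim` attribute and can be passed to the networks.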