I have implement a DQN model via pytorch, but I find that it runs slow during training, and the main reason is that sampling from replay buffer takes up almost all the time.
Here are my implementation of replay buffer.
class DQNBuffer:
def __init__(self, maxlen=100000, device=None):
self.mem = deque(maxlen=maxlen)
self.maxlen = maxlen
self.device = device
def store(self, s, a, r, s_, a_, d):
self.mem.append([s, a, r, s_, a_, d])
def sample(self, batch_size):
bat = random.sample(self.mem, batch_size)
batch = list(zip(*bat))
data = []
for i in range(len(batch)):
data.append(T.as_tensor(batch[i], dtype=T.float32, device=self.device))
return data
def __len__(self):
return len(self.mem)
I compute the consumption of time when training one batch of samples. Here is the code:
import timeit
def train_batch(agent: DQNAgent, batch_size, gamma):
agent.gstep += 1
agent.optimizer.zero_grad()
print("==================================")
A = timeit.default_timer()
s, a, r, s_, a_, d = agent.buf.sample(batch_size)
print(timeit.default_timer() - A)
B = timeit.default_timer()
qhat = agent.tnet(s_, a_).reshape(-1, ).detach()
target = r + (1.0-d)*gamma*qhat
loss = F.mse_loss(agent.qnet(s, a).reshape(-1, ), target)
loss.backward()
agent.optimizer.step()
print(timeit.default_timer() - B)
print("==================================")
return float(loss)
some of the output ( 4 batches ):
==================================
0.11982669999999995
0.011224900000000204
==================================
==================================
0.12090409999999974
0.011249399999999632
==================================
==================================
0.12087029999999999
0.01114000000000015
==================================
==================================
0.11867099999999997
0.011347199999999447
==================================
agent.buf.sample(batch_size)
takes over 90% of the time – it’s too slow.
How to optimize the sample
operation?