TorchRL: Replay Buffer

Given a = ReplayBuffer(), how can I call a.add() with a tuple of (state, action, reward, next_state, done)?

With the most recent version of the library, you should be able to do this:

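```python
from torchrl.data import ListStorage, ReplayBuffer
# collate_fn=lambda x: x returns the sampled items as a plain list of tuples
rb = ReplayBuffer(batch_size=4, collate_fn=lambda x: x, storage=ListStorage(10))
# each transition is stored as a (state, action, reward, next_state, done) tuple
s, a, r, sp, d = range(5)
rb.add((s, a, r, sp, d))
s, a, r, sp, d = range(5, 10)
rb.add((s, a, r, sp, d))
s, a, r, sp, d = range(10, 15)
rb.add((s, a, r, sp, d))
print(rb.sample())
```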


(Removing the collate function will cause the buffer to call torch.stack on the tuples, which won't work. The ListStorage is only there to show how to control the size of the buffer: ListStorage(10) keeps at most 10 transitions.)

Calling rb.sample() then returns something like this (the draw is random, so your output will differ):

[(5, 6, 7, 8, 9), (0, 1, 2, 3, 4), (0, 1, 2, 3, 4), (5, 6, 7, 8, 9)]

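i.e., 4 transitions sampled from your replay buffer. Sampling is random and with replacement, which is why the same transition can appear more than once.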
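To see why duplicates are expected, here is an illustration in plain Python of uniform sampling with replacement. This is not TorchRL's actual sampler code, only the same idea: with 3 stored transitions and a batch size of 4, at least one transition must repeat.

```python
import random

# Three stored transitions, as in the example above
buffer = [(0, 1, 2, 3, 4), (5, 6, 7, 8, 9), (10, 11, 12, 13, 14)]

# Draw 4 indices uniformly *with replacement*: 4 draws from 3 items
# guarantee at least one repeat (pigeonhole principle)
indices = random.choices(range(len(buffer)), k=4)
batch = [buffer[i] for i in indices]
print(batch)
```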

To get the most out of the replay buffer, use TensorDictReplayBuffer with a LazyTensorStorage, which is much faster to write to and sample from.

Hope that helps!