Given a = ReplayBuffer(), how can I call a.add() with a tuple of (state, action, reward, next_state, done)?
With the most recent version of the library, you should be able to do this:
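    from torchrl.data import ListStorage, ReplayBuffer

    # Identity collate_fn: sample() returns a plain list of the stored tuples.
    # ListStorage(10) caps the buffer at 10 transitions.
    rb = ReplayBuffer(batch_size=4, collate_fn=lambda x: x, storage=ListStorage(10))

    # Each transition is stored as one (state, action, reward, next_state, done) tuple
    s, a, r, sp, d = range(5)
    rb.add((s, a, r, sp, d))
    s, a, r, sp, d = range(5, 10)
    rb.add((s, a, r, sp, d))
    s, a, r, sp, d = range(10, 15)
    rb.add((s, a, r, sp, d))

    print(rb.sample())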
(Removing the collate function causes the buffer to call torch.stack on the tuples, which won't work. The ListStorage is only there to show how to control the size of the buffer.)
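which results in something like

    [(5, 6, 7, 8, 9), (0, 1, 2, 3, 4), (0, 1, 2, 3, 4), (5, 6, 7, 8, 9)]

i.e., 4 samples drawn from your replay buffer. Sampling is random and with replacement, so your output will differ and the same transition can appear more than once.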
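To make the roles of the storage size, the sampling, and collate_fn concrete, here is a toy, pure-Python stand-in. It is a hypothetical illustration (the class ToyReplayBuffer and its parameters are invented here), not torchrl's implementation: it assumes a round-robin writer that overwrites the oldest entry once the buffer is full, and uniform sampling with replacement.

```python
import random
from typing import Any, Callable, List, Optional


class ToyReplayBuffer:
    """Hypothetical toy buffer for illustration only; NOT torchrl's code."""

    def __init__(self, max_size: int, batch_size: int,
                 collate_fn: Callable[[List[Any]], Any] = lambda x: x,
                 seed: Optional[int] = None) -> None:
        self.max_size = max_size          # plays the role of ListStorage(max_size)
        self.batch_size = batch_size
        self.collate_fn = collate_fn      # identity by default: return a list of items
        self._storage: List[Any] = []
        self._cursor = 0                  # slot written next (wraps around when full)
        self._rng = random.Random(seed)

    def add(self, item: Any) -> int:
        # Append until full, then overwrite the oldest slot (round-robin)
        if len(self._storage) < self.max_size:
            self._storage.append(item)
        else:
            self._storage[self._cursor] = item
        index = self._cursor
        self._cursor = (self._cursor + 1) % self.max_size
        return index

    def __len__(self) -> int:
        return len(self._storage)

    def sample(self) -> Any:
        # Uniform sampling WITH replacement, so duplicates are possible
        idx = [self._rng.randrange(len(self._storage)) for _ in range(self.batch_size)]
        return self.collate_fn([self._storage[i] for i in idx])


# Capacity 2, three transitions: the first one gets overwritten by the third
rb = ToyReplayBuffer(max_size=2, batch_size=4, seed=0)
for start in (0, 5, 10):
    rb.add(tuple(range(start, start + 5)))  # (state, action, reward, next_state, done)
batch = rb.sample()
print(batch)
```

With a capacity of 2, the transition (0, 1, 2, 3, 4) has been overwritten, so only the two newest tuples can ever show up in a sample.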
If you want to make the most of the replay buffer, use TensorDictReplayBuffer with a LazyTensorStorage, which will be much faster to write to and sample from.
Hope that helps!