[RFC] TorchRL Replay buffers: Pre-allocated and memory-mapped experience replay


TL;DR: We introduce a new memory-mapped storage for Replay Buffers that makes it possible to store a large amount of data across workers and nodes, with low-latency indexing and writing.

Replay buffers are a popular feature of reinforcement learning libraries, so much so that independent packages have been dedicated to them. Getting this right is of the utmost importance to us.

We identify the following features we think researchers will be looking for in a replay buffer:

  • Sampling efficiency: getting the indices of the items to sample should be fast, even with very large buffers (=> prioritized replay buffers).
  • Indexing efficiency: indexing and returning a usable stack of elements from the replay buffer should be fast. Indexing should also be easy (e.g. storage[idx] or similar).
  • Storage modularity: it should be possible to easily choose where the elements of the buffer are stored, e.g. in memory, on disk, or perhaps on CUDA.
  • Sampling modularity: integrating new samplers should be doable with little effort.
  • Distributed access: it should be possible for multiple workers/nodes to access a common replay buffer and write items into it.

As we prepare for a beta release of the library, we thought it would be important to stabilize the replay buffer API while we can still introduce bc-breaking changes in the library. Once we reach the beta stage and beyond, our API should be stable enough that users can rely on it without worrying about upcoming upgrades. Hence, we are doing our best to tick all the above-mentioned boxes as soon as possible.

Regarding sampling efficiency, we have been blessed with the rlmeta C++ implementation of a SegmentTree structure that makes sampling easy and efficient in prioritized settings. The ReplayBuffer class associated with it also has some great features, such as multi-threaded sampling.
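
For readers unfamiliar with the data structure, below is a minimal pure-Python sketch of how a sum-tree supports proportional sampling in O(log n); it only illustrates the principle and is unrelated to the actual rlmeta C++ implementation.

import random

class SumTree:
    """Toy sum-tree: leaves hold priorities, internal nodes hold the sum of their children."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.nodes = [0.0] * (2 * capacity)

    def update(self, idx, priority):
        pos = idx + self.capacity
        self.nodes[pos] = priority
        pos //= 2
        while pos >= 1:  # propagate the new sum up to the root
            self.nodes[pos] = self.nodes[2 * pos] + self.nodes[2 * pos + 1]
            pos //= 2

    def sample(self):
        mass = random.uniform(0.0, self.nodes[1])  # nodes[1] holds the total priority
        pos = 1
        while pos < self.capacity:  # walk down, picking the child that contains `mass`
            left = 2 * pos
            if mass <= self.nodes[left]:
                pos = left
            else:
                mass -= self.nodes[left]
                pos = left + 1
        return pos - self.capacity

tree = SumTree(8)
for i, p in enumerate([1.0, 2.0, 4.0, 1.0]):
    tree.update(i, p)
print([tree.sample() for _ in range(10)])  # index 2 (priority 4.0) should dominate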

As of now, we have one dedicated replay-buffer class per sampling strategy. This means that adding a new sampler requires implementing a new RB class, which may be suboptimal as a great deal of the existing features will remain identical. We are contemplating the idea of making our replay buffer more similar to Reverb's (and, for what it's worth, to the Dataloader(dataset, sampler, collate_fn) API, except that adding elements to the buffer must be implemented as well). This would make the buffer class behave as

buffer = ReplayBuffer(sampler=sampler, storage=storage, collate_fn=collate_fn)

and in the future a remover can be added as well (feel free to upvote that feature if you feel it is of the utmost importance to get this fast!).
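
To make the idea concrete, here is a purely hypothetical sketch of what a pluggable sampler could look like under such an API; the class name and signature below are illustrative only and not part of TorchRL.

import torch

class UniformSampler:
    """Hypothetical sampler: given a storage, return the indices of the items to fetch."""
    def sample(self, storage, batch_size):
        return torch.randint(0, len(storage), (batch_size,))

# hypothetical usage, mirroring the contemplated constructor above:
# buffer = ReplayBuffer(sampler=UniformSampler(), storage=storage, collate_fn=collate_fn)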

Storage and indexing

Storage and indexing are trickier to get right: until now, our storage simply consisted of appending items one at a time to a list, from which items were indexed and stacked using an ad-hoc collate_fn. This has the notable advantage of being robust to samples of different sizes or dtypes (e.g. truncated trajectories), as the assumptions about the incoming data were minimal. However, it requires us to index a list and then process the collected items accordingly, which is suboptimal. The other inconvenience of this approach is that it can be hard to foresee the memory requirement of the buffer: if images or videos are collected, it may very well happen that the training script is interrupted after some time due to an out-of-memory issue, even more so if the collected items have a variable size.
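
For reference, the list-based approach sketched above boils down to something like this (a schematic snippet, not the actual ListStorage implementation):

import torch

storage = []
collate_fn = lambda items: torch.stack(items, 0)  # ad-hoc collation of the sampled items

for _ in range(1000):
    storage.append(torch.randn(3, 4))  # items are appended one at a time

idx = torch.randint(0, len(storage), (128,))
batch = collate_fn([storage[i] for i in idx])  # index the list item by item, then collate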


To solve these issues, we introduce a new storage class, LazyMemmapStorage, which stores tensors on disk using a tensor-like class (MemmapTensor) that acts as an interface between np.memmap and torch.Tensor. np.memmap has significant advantages: first, it is easy to specify where the data should be stored. In many cases, the scratch space location is tailored to the infra and cluster requirements, making a default /tmp/file location unsuitable. Second, indexing from np.memmap is easy and efficient: those arrays can be accessed at multiple different indices without much overhead. Third, the upper limit on the number of MemmapTensors opened at the same time is looser than the limit on shared-memory tensors, for instance. Finally, storing these data on disk rather than in memory makes it easy to retrieve data after a job failure. However, MemmapTensors do not support differentiation, but this is an unusual feature for tensors stored in a replay buffer anyway.
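
As an illustration of the underlying mechanism (not of the MemmapTensor implementation itself), a numpy memmap can be created at an arbitrary location and viewed as a torch tensor without copying:

import numpy as np
import torch

# a minimal sketch of the idea, assuming /tmp/demo_buffer.dat is a valid scratch location
array = np.memmap("/tmp/demo_buffer.dat", dtype=np.float32, mode="w+", shape=(1000, 3))
tensor = torch.from_numpy(array)  # zero-copy view on the file-backed array

tensor[10] = torch.randn(3)       # writes go straight to the memory-mapped file
print(tensor[[10, 42, 7]])        # fancy indexing reads several rows in one call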

The API can now be sketched in the following way (for a single tensor):

buffer = ReplayBuffer(
    cfg.buffer_size,
    collate_fn=lambda tensors: tensors,
    storage=LazyMemmapStorage(cfg.buffer_size)
)

As the name indicates, the storage is lazy in the sense that it will be populated once it reads the first tensor that it is given.

>>> buffer.add(torch.randn(3))  # `add` adds a single element to the buffer 
Creating a MemmapStorage...
The storage was created in /tmp/tmpz_3dpje2 and occupies 0.00011444091796875 Mb of storage.
>>> buffer.sample(1).shape
torch.Size([1, 3])
>>> buffer._storage[2]  # not assigned yet
torch.tensor([0., 0., 0.])

This allows for easy-to-generalize buffer initialization across algorithms and environments. In our implementation, the indexing strategy is the responsibility of the storage class: unlike the regular ListStorage, LazyMemmapStorage will not access the indexed items one at a time and stack them using collate_fn. Instead, it builds a stack of indices such that the memory-mapped array is indexed only once. In practice, this brings a speed-up of one or two orders of magnitude.
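
The gap can be reproduced with a small, self-contained snippet (absolute numbers will vary with hardware):

import timeit
import torch

items = [torch.randn(3, 4) for _ in range(100_000)]
preallocated = torch.stack(items, 0)               # contiguous, pre-allocated storage
idx = torch.randint(0, 100_000, (128,))

# ListStorage-style: fetch items one by one, then stack them
t_list = timeit.timeit(lambda: torch.stack([items[i] for i in idx], 0), number=1000)
# LazyTensorStorage/LazyMemmapStorage-style: a single indexing operation
t_flat = timeit.timeit(lambda: preallocated[idx], number=1000)
print(t_list, t_flat)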

Of course, replay buffers usually need to store not one tensor category but many of them. For these cases, TorchRL provides a TensorDict class that broadcasts common operations, such as reshaping, moving to device and indexing, across tensors stored in a dictionary. Naturally, LazyMemmapStorage supports TensorDict inputs seamlessly:

>>> td = TensorDict({"a": torch.randn(10, 3), "b": torch.ones(10, 1)}, batch_size=[10])
>>> storage = LazyMemmapStorage(10, scratch_dir="/scratch/")
>>> buffer = ReplayBuffer(
...    10,
...    collate_fn=lambda x: x,  # we don't need to do anything, TensorDict will do it for us
...    storage=storage)
>>> buffer.extend(td)  # `extend` adds multiple elements to the buffer (unlike `add`) 
Creating a MemmapStorage...
The storage is being created: 
	a: /tmp/tmp1m7m9uig, 0.00011444091796875 Mb of storage (size: [10, 3]).
	b: /tmp/tmpsilpr0kg, 3.814697265625e-05 Mb of storage (size: [10, 1]).
>>> sample = buffer.sample(3)
>>> sample
SubTensorDict(
    fields={
        a: MemmapTensor(torch.Size([3, 3]), dtype=torch.float32),
        b: MemmapTensor(torch.Size([3, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
>>> sample.get("b")
tensor([[1.],
        [1.],
        [1.]])

Again, the collate function is the identity, as the only thing we need to do to gather a stack of samples is to index the corresponding tensordict (under the hood, the only operation that LazyMemmapStorage performs is out_tensordict = self.memmap_tensordict[idx]).
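
For reference, indexing a TensorDict applies the index to every entry along the batch dimensions, which is why no extra collation is needed:

>>> import torch
>>> from torchrl.data import TensorDict
>>> td = TensorDict({"a": torch.randn(10, 3), "b": torch.ones(10, 1)}, batch_size=[10])
>>> sub = td[torch.tensor([0, 4, 2])]  # one indexing call gathers every key at once
>>> sub.batch_size
torch.Size([3])
>>> sub.get("a").shape
torch.Size([3, 3])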

Performance

We tested the replay buffer sampling speed across three types of storage (list, pre-allocated tensors, pre-allocated memory-mapped tensors) and across three types of data input (simple 3 x 4 floating-point tensors, dictionaries of tensors, and the corresponding TensorDict). The test consisted of populating a RB with 100K examples and sampling batches of 128 elements from it. The test file can be found attached to this note. The complexity of storing and indexing simple tensors (left group on the figure) is obviously lower as there is only a single data source. Storing dictionaries (middle group) was only tested using the regular ListStorage as dictionaries cannot be indexed (unlike TensorDicts): in other words, making dictionaries indexable so that they can be used with LazyTensorStorage and LazyMemmapStorage is precisely the purpose of the TensorDict class.

[Figure: sampling time (s) for each data type (torch.Tensor, dict, TensorDict) and storage type (ListStorage, LazyTensorStorage, LazyMemmapStorage)]

We can observe that sampling a single, stored, pre-allocated tensor (right group, either in memory – orange – or on disk – green) is much faster than sampling from a list and stacking the elements together (right group, blue): in other words, memory-mapped and pre-allocated tensors are faster than lists, as expected. Also, the overhead of storing items on disk is negligible, though this depends on the storage specs. Nevertheless, notice that TensorDict already does a decent job with the ListStorage compared to naive dictionaries (middle vs right blue bar).

Finally, regarding the distributed side of things, LazyMemmapStorage paves the way to shared storages that can be placed anywhere on the available partitions: if nodes have access to a common space with reasonably fast storage, they will all be capable of writing items to the buffer and accessing it at low cost. Stay tuned to hear more about this!
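
As a sketch of where this is heading (only a sketch: the multi-node machinery is not released yet, and the path below is hypothetical), each worker would simply point its storage to the same shared mount:

from torchrl.data import ReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

# hypothetical scratch location visible from every node of the cluster
shared_scratch = "/mnt/shared_scratch/replay_buffer"
buffer = ReplayBuffer(
    1_000_000,
    collate_fn=lambda x: x,
    storage=LazyMemmapStorage(1_000_000, scratch_dir=shared_scratch),
)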


Here’s the code to generate the figure:

import pandas as pd
import seaborn as sbn
from matplotlib import pyplot as plt

from torchrl.data.replay_buffers.storages import LazyMemmapStorage, \
    ListStorage, LazyTensorStorage
from torchrl.data import TensorDict, TensorDictReplayBuffer, ReplayBuffer
import torch
import timeit


def collate_fn_dict(list_of_dict):
    keys = list_of_dict[0].keys()
    d = dict()
    for k in keys:
        d[k] = torch.stack([_d[k] for _d in list_of_dict], 0)
    return d


def test(prefetch, L=100000, N=100000, dtype="tensor", storage="list"):
    print("dtype: ", dtype, "// storage: ", storage)
    if storage == "list":
        storage = ListStorage()
        collate_fn = lambda x: torch.stack(x, 0)
    elif storage == "memmap":
        storage = LazyMemmapStorage(L)
        collate_fn = lambda x: x
    elif storage == "tensor":
        storage = LazyTensorStorage(L)
        collate_fn = lambda x: x
    else:
        raise NotImplementedError
    if dtype == "tensor":
        rb1 = ReplayBuffer(
            L, storage=storage,
            prefetch=prefetch,
            collate_fn=collate_fn
        )
    elif dtype == "dict":
        rb1 = ReplayBuffer(
            L, storage=storage,
            prefetch=prefetch,
            collate_fn=collate_fn_dict
        )
    else:
        rb1 = TensorDictReplayBuffer(
            L,
            storage=storage,
            prefetch=prefetch,
            collate_fn=collate_fn
        )

    # fill them
    for _ in range(L // N):
        if dtype == "td":
            td = TensorDict({
                'a': torch.randn(N, 3, 4),
                'b': torch.randn(N, 3, 4),
                'c': torch.zeros(N, dtype=torch.bool),
            }, [N]
            )
        elif dtype == "dict":
            td = [{
                'a': torch.randn(1, 3, 4),
                'b': torch.randn(1, 3, 4),
                'c': torch.zeros(1, dtype=torch.bool),
            } for _ in range(N)]
        else:
            td = torch.randn(N, 3, 4)
        rb1.extend(td)
    t = timeit.timeit("rb1.sample(128)", globals={"rb1": rb1}, number=10000)
    print(t)
    return t


array = torch.zeros(3, 3)
for i, storage in enumerate(["list", "tensor", "memmap"]):
    for j, dtype in enumerate(["tensor", "dict", "td"]):
        if dtype == "dict" and storage != "list":
            continue
        t = test(0, dtype=dtype, storage=storage)
        array[i, j] = t
df = pd.DataFrame(array.numpy(), columns=["torch.Tensor", "dict", "TensorDict"], index=["ListStorage", "LazyTensorStorage", "LazyMemmapStorage"])
df = df.stack(0).rename_axis(("storage", "datatype")).reset_index()
df = df.rename(columns={0: "time (s)"})
print(df)
sbn.barplot(x="datatype", y="time (s)", hue="storage", data=df)
plt.show()

Hello, this is indeed a great addition, and extremely useful for RL practitioners.

I have been trying to use TensorDict and TensorDictPrioritizedReplayBuffer but I seem to be having an issue with the sample index. I hope this is the right place to post such a question. I am not sure if it is me doing something wrong or some mismatch with torchrl. More specifically, for a DQN application, I initialize my memory here:

memory = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, eps=0.01, storage=LazyTensorStorage(REPLAY_SIZE), batch_size=BATCH_SIZE)

with REPLAY_SIZE = 10000 and BATCH_SIZE = 5.
Then in each iteration I create my S-A-R-S' data and extend the memory with the data (I also print the index and the initial priority):

data = TensorDict({"state": state.cpu(), "action": action_indx_torch.unsqueeze(0).cpu(), "next_state": next_state, "reward": rew_torch.unsqueeze(0).cpu()}, [1])

In the data TensorDict I use [1] because if I set [REPLAY_SIZE] I get the following error: RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([10000]) and value.shape[:self.batch_dims]=torch.Size([1])

indx = memory.extend(data)

print('indx=', indx)

print(memory._get_priority(data))

Then I sample with batch_size. I use _sample instead of sample to see the weights:

batch_sample, info = memory._sample(BATCH_SIZE)

Then I print the index of the batch, the priority, and the action key for the samples, and finally I set and update the priority for the batch.

print("index", batch_sample["index"])
print("priority", [memory._get_priority(td) for td in batch_sample])

for td in batch_sample:
    print('samples', td['action'])

priority = torch.sum(torch.abs(ESAV - SAV),dim=0)

# give a TD0 priority to these samples
batch_sample.set("td_error", priority)

# and update the priority
memory.update_tensordict_priority(batch_sample)

=============================================
Here is the output at the 20th iteration:

indx= [19] %correctly extends memory with new index 19 for the new data
1.4470902392888914 %priority of new data point

info= {'_weight': array([0.7642294, 0.7642294, 0.9964536, 0.7642294, 0.7642294], dtype=float32)} %after sampling with BATCH_SIZE=5 we get these weights

PROBLEM IS HERE: When I print the index for the batch sample of size 5, I always get the same index 0 and the same priority using _get_priority, because we look up the priority using the index.

index tensor([0, 0, 0, 0, 0], dtype=torch.int32)
priority_key [1.4470902392888914, 1.4470902392888914, 1.4470902392888914, 1.4470902392888914, 1.4470902392888914]

However, the 5 samples for the action key in the batch are sampled correctly. The samples are indeed different inside the batch, so I do not understand why the index of each sample is 0 (and the priority is the same):

samples tensor([8, 5, 6])
samples tensor([2, 2, 4])
samples tensor([8, 2, 7])
samples tensor([7, 1, 9])
samples tensor([7, 9, 7])

What would be the reason for the wrong indexing? Thank you very much in advance.

Thanks for reporting this
There was such a bug a couple of weeks ago and I fixed it in [BugFix,Feature,Doc] Fix replay buffers sampling info, docstrings and… · pytorch/rl@e26d148 · GitHub
Can you check that the index is still messed up with the latest torchrl-nightly?

I tried to run (a version of) this code.
Here I extend the replay buffer twice, although you could have a non-dimensional tensordict and call memory.add instead (notice also that I call data.cpu() to cast the elements to CPU).

import torch

from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage
from tensordict import TensorDict

BATCH_SIZE = 10
REPLAY_SIZE = 1000
memory = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, eps=0.01, storage=LazyTensorStorage(REPLAY_SIZE), batch_size=BATCH_SIZE)

rew_torch, state, action_indx_torch, next_state = torch.randn(4, 1, 10)
data = TensorDict({"state": state, "action": action_indx_torch.unsqueeze(0), "next_state": next_state, "reward":rew_torch.unsqueeze(0)}, [1]).cpu()

# there is one element in the replay buffer to sample from
indx = memory.extend(data)
# there are 2 elements to sample from
indx = memory.extend(data)

batch_sample = memory.sample(batch_size=BATCH_SIZE)
print(batch_sample["index"], indx)

for td in batch_sample:
    ESAV, SAV = torch.rand(2, 1, *batch_sample.shape)
    priority = torch.sum(torch.abs(ESAV - SAV), dim=0)
    batch_sample.set("td_error", priority)
    memory.update_tensordict_priority(batch_sample)

The index is a mix of 0 and 1 as expected (since there are 2 elements in the buffer; you can check that via len(memory)).

I hope this helps, LMK if anything is unclear and/or missing from the doc!

Great! Thank you very much for your answer. I installed torchrl-nightly AND tensordict-nightly (otherwise it does not run), so now, when I call

transitions = memory.sample(batch_size=BATCH_SIZE)
print("batch_sample[index]", transitions["index"])

I correctly get non-zero indices, e.g.

batch_sample[index] tensor([13, 6, 1, 13, 11])

However, I think there is another issue, because when I want to check the priority weights

print("priority", [memory._get_priority(td) for td in batch_sample])

these priorities do not update! I keep getting the same vector over all iterations

priority [1.5434609531703256, 1.5434609531703256, 1.5434609531703256, 1.5434609531703256, 1.5434609531703256]

To update the priorities I use the following commands per iteration:

# give a TD0 priority to these samples
batch_sample.set("td_error", TD0)
# and update the priority
memory.update_tensordict_priority(batch_sample)

Here, TD0 is a tensor with the same size as batch_size, e.g.

TD0: tensor([0.6085, 1.3825, 1.3924, 0.6085, 1.0975], grad_fn=)

So I worry that since priorities are not updated the sampling from the Prioritized buffer will just be uniform, and no prioritization will be taken into account. Maybe there is something wrong with my code?

Another minor issue is that when I run (memory._sample)

transitions, info = memory._sample(BATCH_SIZE)
print('info=', info)
print("batch_sample[index]", transitions["index"])

I get the correct indices in the info variable but all-zero indices in the batch_sample:

info= {'_weight': array([0.99792635, 0.7333868 , 0.7333868 , 0.99792635, 0.7333868 ],
dtype=float32), 'index': array([ 2, 6, 11, 2, 9])}
batch_sample[index] tensor([0, 0, 0, 0, 0], dtype=torch.int32)

Not sure what your example is (can you paste a complete script with the proper formatting? That would really help!)

There is a “bug”: if you have the same index multiple times, the priority assigned will be the last one in the tensordict.

Besides that, this works on my end (i.e. after updating the priority of index 1 I always sample items with index 1):

import torch

from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage
from tensordict import TensorDict

BATCH_SIZE = 10
REPLAY_SIZE = 1000
memory = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, eps=0.01, storage=LazyTensorStorage(REPLAY_SIZE), batch_size=BATCH_SIZE, prefetch=0)

rew_torch, state, action_indx_torch, next_state = torch.randn(4, 1, 10)
data = TensorDict({"state": state, "action": action_indx_torch.unsqueeze(0), "next_state": next_state, "reward":rew_torch.unsqueeze(0)}, [1]).cpu()

# there is one element in the replay buffer to sample from
indx = memory.extend(data.clone())
# there are 2 elements to sample from
indx = memory.extend(data.clone())

batch_sample = memory.sample(batch_size=BATCH_SIZE)
print(batch_sample["index"], indx)

td_error = torch.ones(*batch_sample.shape, 1)
td_error[batch_sample["index"] == 1] = 1e30
batch_sample["td_error"] = td_error
memory.update_tensordict_priority(batch_sample)

batch_sample2 = memory.sample(batch_size=10)
print(batch_sample2["index"])  # all 1!


Another point: I would try to avoid using the private methods (IDK if you’re using them in your code or for debugging) but as they are “private” we may change them without raising a deprecation warning in the future, and hence break your workflow!

Thank you again for your answer. I confirm that everything now works fine. I also did the suggested test of enforcing a very high priority value for a selected index (e.g. indx=1), and after that, in the next iterations, the batch_sample contains indices of the selected one (e.g. all 1).

My confusion came from printing

print("priority", [memory._get_priority(td) for td in batch_sample])

which, as you mention, is a private function; I was not supposed to be using it, but I did so for debugging reasons. This being said, for the sake of better modularity of code, it would be great if the memory.sample(batch_size=BATCH_SIZE) command returned a sample dictionary containing the priority values as well, and not only the indices (I am not sure what the difference is between 'weights' and 'priorities').

Thank you very much again for your immediate answer, it has been very helpful. Torchrl is a great initiative!

You are right, these ought to be fixed. Can you submit an issue with the papercuts you faced to help us prevent that for other users? That would be super useful!
