[RFC] torch.dict.TensorDict

This PR proposes to integrate the core features of tensordict in torch core under torch.dict.
We welcome comments, contributions and suggestions!
To get a sense of what tensordict can do, you can install it from PyPI (pip install tensordict) or try the PR directly.

Purpose

TensorDict is a data carrier for PyTorch. It provides a uniform way of carrying more than one tensor in a code base, without having to specify beforehand which entries are present in the data structure, and of operating on these tensors as a batch.
More specifically, it solves the following problems in PyTorch:

  • reduces boilerplate when the same op has to be applied to a batch of tensors (eg, samples from a dataset, parameters from a model)
  • isolates parameters from a model and enables operations on these independently from the model structure (eg, (a)sync communication, gathering, splitting/sharding etc), essentially decoupling the model (algebraic operations) from the data it operates over (incl. parameters)
  • better representation (print) of tensor structures
  • more straightforward functional API
  • intuitive parameter and data serialization, with efficient indexing via memmap

In brief, the goal is to be able to write generic data pipelines (transforms, preprocessing) as well as parameter manipulation (point-to-point communication, serialization, sharding) without explicitly looping over the set of tensors being provided.

It is based on a few basic rules (think of them as axioms) from which all other functionalities follow naturally:

  1. A TensorDict behaves like a dictionary of tensors or TensorDict instances; as such, it allows nesting;
  2. A TensorDict “shape” (or batch-size, in contrast with feature-size) is arbitrary, required upon construction, and constrains all of its content to share the same first N dimensions. This feature allows shape operations and indexing;
  3. TensorDict never constrains the dtype of its content;
  4. TensorDicts have an optional device. If specified, all tensors stored in the tensordict are moved to that device.

This simple set of rules completely defines TensorDict’s expected behaviour. For instance, given that a TensorDict has a shape, __getitem__ supports both key-based and shape-based indexing. Because multiple dtypes (as well as arbitrary feature sizes) can be stored and are not checked, algebraic operations are out of scope.
When nesting TensorDicts within each other, each inner node must have at least as many batch dimensions as its parent, and its leading dimensions must match the parent's batch-size.
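Rule 2 and the nesting constraint can be checked mechanically: every entry, including a nested TensorDict whose "shape" is its batch-size, must start with its parent's batch-size. The stdlib-only sketch below works on shapes alone (no tensors); `Node` and `check_batch_size` are hypothetical names for illustration, not part of the tensordict API.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple, Union

Shape = Tuple[int, ...]


@dataclass
class Node:
    """Hypothetical stand-in for a TensorDict: a batch-size plus named entries."""

    batch_size: Shape
    fields: Dict[str, Union[Shape, "Node"]] = field(default_factory=dict)


def check_batch_size(node: Node) -> None:
    """Raise ValueError if any entry's shape does not start with the node's batch-size."""
    n = len(node.batch_size)
    for key, value in node.fields.items():
        # A nested node contributes its own batch-size as its "shape".
        shape = value.batch_size if isinstance(value, Node) else value
        if tuple(shape[:n]) != tuple(node.batch_size):
            raise ValueError(f"{key!r} has shape {shape}, expected prefix {node.batch_size}")
        if isinstance(value, Node):
            check_batch_size(value)  # the same rule applies at every nesting level


# Same nesting as the pytree example of this RFC: [] -> [2] -> [2, 3].
root = Node((), {
    "a": Node((2,), {"b": Node((2, 3), {"c": (2, 3, 4)}), "d": (2,)}),
    "e": (),
})
check_batch_size(root)  # passes
```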

Basic usage and concepts

A TensorDict can be created explicitly from a set of tensors or arrays:

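import numpy as np
import torch
from torch.dict import TensorDict  # proposed namespace; today: from tensordict import TensorDict

data = TensorDict({
    "tensor": torch.zeros(3),
    "from_np": np.zeros((3,)),
    "from_int": 1,
    "from_list": [[1, 2, 3]],
    "from_dict": {"a": 1},
}, batch_size=[])  # the batch-size must always be specified
print(data)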

which results in

TensorDict(
    fields={
        from_dict: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        from_int: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        from_list: Tensor(shape=torch.Size([1, 3]), device=cpu, dtype=torch.int64, is_shared=False),
        from_np: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
        tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

or directly from a module containing parameters:

from torch import nn

net = nn.Sequential(nn.Linear(3, 4))
params = TensorDict.from_module(net)
print(params)

which results in

TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

A key concept in TensorDict is NestedKey: a NestedKey is anything from a regular string to a tuple of tuples of strings.
For instance, in the example above, params.get("0"), params.get(("0",)) and params.get((("0",),)) all return the same value.
The reason we allow nested tuples becomes apparent when keys are appended one after another: say you want to get the weight in the example above, but the code that retrieves it is only given a prefix, which can be either a tuple or a string (ie, a NestedKey):

def get_weight(params, prefix):
    return params.get((prefix, "weight"))

Allowing for nested tuples avoids having to check for the prefix (or in other cases suffix) type when concatenating keys.
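To illustrate why nested tuples are convenient, here is a minimal sketch of how a nested key could be flattened before a lookup, assuming a plain nested dict in place of a TensorDict. `flatten_nested_key` and `nested_get` are hypothetical helpers; the library's own key handling may differ.

```python
from typing import Any, Mapping, Tuple, Union

# A nested key: a string, or a (possibly nested) tuple of strings.
NestedKey = Union[str, Tuple[Any, ...]]


def flatten_nested_key(key: NestedKey) -> Tuple[str, ...]:
    """Turn "a", ("a",), (("a",),) or ("a", ("b", "c")) into a flat tuple of strings."""
    if isinstance(key, str):
        return (key,)
    flat: Tuple[str, ...] = ()
    for part in key:
        flat += flatten_nested_key(part)  # recurse into sub-tuples
    return flat


def nested_get(data: Mapping[str, Any], key: NestedKey) -> Any:
    """Walk a nested mapping one flat key component at a time."""
    node: Any = data
    for part in flatten_nested_key(key):
        node = node[part]
    return node


def get_weight(params: Mapping[str, Any], prefix: NestedKey) -> Any:
    # (prefix, "weight") is valid whether prefix is a str or a tuple: no type check needed.
    return nested_get(params, (prefix, "weight"))


params = {"0": {"weight": "w0", "bias": "b0"}, "block": {"0": {"weight": "w1"}}}
print(get_weight(params, "0"))             # w0
print(get_weight(params, ("block", "0")))  # w1
```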

Features

Tensor-like features

Tensor-like features are focused on shape-based operators, device and dtype casting, and distributed methods. Algebraic operations are not supported directly, but can easily be executed with TensorDict.apply (see below).

  • shape-based operators: reshape, view, permute, transpose, flatten, unflatten
  • Indexing and masking: __getitem__ and __setitem__ support all indices supported by torch.Tensor. Masking can be particularly useful:
    data = TensorDict({"key1": torch.arange(12).view(3, 4, 1)}, batch_size=[3, 4, 1])
    mask = torch.rand(3, 4) > 0.5  # random boolean mask over the first two batch dims
    data[mask]  # a tensordict of shape [mask.sum(), 1]
    
  • Augmented torch functions: torch.ones_like, torch.zeros_like, torch.full_like, torch.empty_like, torch.clone, torch.squeeze, torch.unsqueeze, torch.masked_select, torch.permute, torch.where and torch.gather
  • Padding: one can pad a single tensordict along the batch-dimensions, but also stack a list of tensordicts of different shapes with a given padding value (see torch.dict.pad_sequence)
  • Stacking and concatenating: torch.stack and torch.cat have been augmented to work with tensordict instances.
  • Split, unbind, chunk
  • Casting: TensorDict.to(...) has the same signature as Tensor.to(...) while also accepting other TensorDictBase subtypes
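Since the batch-size is a prefix of every entry's shape, shape operators only need to rewrite that prefix and can leave the feature dimensions alone. The sketch below shows this at the level of shapes for a reshape, assuming a flat dict of leaf shapes; `reshape_batch` is a hypothetical helper, not the library's implementation.

```python
from math import prod
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def reshape_batch(leaf_shapes: Dict[str, Shape], batch_size: Shape, new_batch: Shape) -> Dict[str, Shape]:
    """Shape-level sketch of a batch reshape: only the leading batch dims change.

    Everything after the first len(batch_size) dims is a feature dimension and is
    kept as-is, which is why leaves with different feature sizes can be reshaped together.
    """
    if prod(batch_size) != prod(new_batch):
        raise ValueError(f"cannot reshape batch-size {batch_size} into {new_batch}")
    n = len(batch_size)
    new_shapes = {}
    for key, shape in leaf_shapes.items():
        if tuple(shape[:n]) != tuple(batch_size):
            raise ValueError(f"{key!r} does not start with batch-size {batch_size}")
        new_shapes[key] = tuple(new_batch) + tuple(shape[n:])
    return new_shapes


shapes = {"pixels": (3, 4, 64, 64), "reward": (3, 4), "done": (3, 4, 1)}
print(reshape_batch(shapes, (3, 4), (12,)))
# {'pixels': (12, 64, 64), 'reward': (12,), 'done': (12, 1)}
```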

Dict-like features

TensorDict comes with many dict-like features:

  • get(key, default=some_val) where key is a NestedKey (see above). The default value can be anything, not only a compatible type (tensor or tensordict);
  • pop(key, default=some_val)
  • setdefault(key, value)
  • del data[key]
  • data.rename_key_(key_origin, key_dest)
  • data.unflatten_keys(separator=".") which unflattens a tensordict along the key dimension (particularly useful when representing state_dicts: data = TensorDict(module.state_dict(), []).unflatten_keys("."))
  • data.flatten_keys(separator=".") which flattens a tensordict (particularly useful when one needs to represent parameters as a state_dict)
  • keys, values and items all accept the keyword arguments include_nested and leaves_only (both False by default to match dict behaviour): the former also iterates over the content of sub-tensordicts, the latter skips non-leaf (ie, TensorDict) values.
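The key-dimension features can be sketched on plain nested dicts. `flatten_keys`, `unflatten_keys` and `iter_keys` below are hypothetical stand-ins that mimic the behaviour described in these bullets (separator-joined keys, tuple-valued nested keys); they are not the tensordict code.

```python
from typing import Any, Dict, Iterator, Tuple, Union


def flatten_keys(data: Dict[str, Any], separator: str = ".") -> Dict[str, Any]:
    """Collapse nested dicts into a single level, joining keys with `separator`."""
    flat: Dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, dict):
            for sub_key, sub_value in flatten_keys(value, separator).items():
                flat[key + separator + sub_key] = sub_value
        else:
            flat[key] = value
    return flat


def unflatten_keys(data: Dict[str, Any], separator: str = ".") -> Dict[str, Any]:
    """Inverse of flatten_keys: split each key on `separator` and rebuild the nesting."""
    nested: Dict[str, Any] = {}
    for key, value in data.items():
        *parents, leaf = key.split(separator)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


def iter_keys(data: Dict[str, Any], include_nested: bool = False, leaves_only: bool = False,
              prefix: Tuple[str, ...] = ()) -> Iterator[Union[str, Tuple[str, ...]]]:
    """Yield top-level keys as strings and nested keys as tuples."""
    for key, value in data.items():
        full_key = prefix + (key,)
        is_node = isinstance(value, dict)
        if not (leaves_only and is_node):
            yield key if not prefix else full_key
        if include_nested and is_node:
            yield from iter_keys(value, include_nested, leaves_only, full_key)


state_dict = {"0.weight": "w0", "0.bias": "b0", "1.weight": "w1"}
nested = unflatten_keys(state_dict)
print(nested)  # {'0': {'weight': 'w0', 'bias': 'b0'}, '1': {'weight': 'w1'}}
print(list(iter_keys(nested, include_nested=True, leaves_only=True)))
# [('0', 'weight'), ('0', 'bias'), ('1', 'weight')]
```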

Arbitrary functions

  • TensorDict.apply: applies a function to all the values of the tensordict (recursively). It also accepts other tensordict instances, which makes algebraic operations available out of the box:
data_x.apply(lambda x, y: x + y, data_y)
data.apply(lambda x: x.norm().cpu(), batch_size=[], device="cpu")  # if the device or batch-size change, they can be specified here
  • TensorDict.map: like the map function of Hugging Face's datasets library, it maps a function over the whole tensordict along a given dimension:
              >>> import torch
              >>> from torch.dict import TensorDict
              >>>
              >>> def process_data(data):
              ...     data.set("y", data.get("x") + 1)
              ...     return data
              >>> if __name__ == "__main__":
              ...     data = TensorDict({"x": torch.zeros(()).expand(1, 1_000_000)}, [1, 1_000_000]).memmap_()
              ...     data = data.map(process_data, dim=1)
              ...     print(data["y"][:, :10])
              ...
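The recursion behind TensorDict.apply with a second tensordict can be sketched on nested dicts: matching leaves from every structure are passed to the function together, which is all an element-wise x + y needs. This `apply` is a hypothetical, stdlib-only illustration; it ignores the batch_size and device arguments.

```python
from typing import Any, Callable, Dict


def apply(fn: Callable[..., Any], data: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively call `fn` on every leaf of `data`.

    Each dict in `others` must have the same structure as `data`; its matching
    leaves are passed to `fn` as extra positional arguments.
    """
    out: Dict[str, Any] = {}
    for key, value in data.items():
        matching = [other[key] for other in others]
        if isinstance(value, dict):
            out[key] = apply(fn, value, *matching)  # recurse into sub-dicts
        else:
            out[key] = fn(value, *matching)
    return out


data_x = {"a": 1.0, "sub": {"b": 2.0}}
data_y = {"a": 10.0, "sub": {"b": 20.0}}
print(apply(lambda x, y: x + y, data_x, data_y))  # {'a': 11.0, 'sub': {'b': 22.0}}
```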
    

Serialization

Disclaimer: parts of this API are WIP as we’re working on individual tensor serialization.

One of the goals of tensordict is to provide an efficient on-disk representation of data, tracking file names and locations while preserving fast memmap-based indexing. This allows us to represent huge datasets composed of multiple tensors on disk. If part of the preprocessing can be done ahead of time, this can bring a tremendous speed-up to data loading:
[Figure: data-loading speed comparison]

For parameter serialization, one can do something as simple as

net = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4))
params = TensorDict.from_module(net)
params.memmap_(prefix="/path/to/params")

which will result in the following data structure:

/path/to/params
│
├── 0
│ ├── weight.memmap
│ └── bias.memmap
│
└── 1
  ├── weight.memmap
  └── bias.memmap
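This layout follows directly from the nesting: each sub-tensordict becomes a directory and each leaf a file. The sketch below only computes those paths for a nested dict (it writes nothing to disk); `memmap_layout` is a hypothetical helper, and the actual library may store extra metadata files.

```python
from pathlib import PurePosixPath
from typing import Any, Dict, List


def memmap_layout(data: Dict[str, Any], prefix: str) -> List[str]:
    """Return one file path per leaf; every nesting level becomes a sub-directory."""
    root = PurePosixPath(prefix)
    paths: List[str] = []
    for key, value in data.items():
        if isinstance(value, dict):
            paths += memmap_layout(value, str(root / key))
        else:
            paths.append(str(root / f"{key}.memmap"))
    return paths


# Leaves stand in for the parameters of nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4)).
params = {"0": {"weight": None, "bias": None}, "1": {"weight": None, "bias": None}}
for path in memmap_layout(params, "/path/to/params"):
    print(path)
```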

Currently, these memory-mapped tensors are a subclass of torch.Tensor, but we are looking into a refactoring in which they would be plain tensors.

Food for thought and topic for discussion: how should this work with parameter tying? Since we copy each tensor to a location that mirrors its location in the module, what should happen when the same parameter appears at multiple places within the module?

Functorch compatibility

TensorDict is compatible with several features from functorch:

  • vmap: it is easy to vmap over a TensorDict and the shape is treated in the same way as a Tensor shape:

    td = TensorDict({"a": torch.zeros(3, 4, 5)}, batch_size=[3, 4])
    ones = torch.ones(4, 5)
    td_out = torch.vmap(lambda td, one: td.set("a", td.get("a") + one), in_dims=(0, None), out_dims=(1,))(td, ones)
    assert td_out.shape == torch.Size([4, 3])  # the vmapped dim (size 3) now sits at position 1, as per `out_dims`
    

    The core of the feature lies in the batch-size preservation of these operations, which requires TensorDict to be registered within the vmap-related functions.

  • functional_call (WIP): since TensorDict can efficiently represent parameters, it is natural to use it for functional calls with nn.Modules. Unlike the “regular” API, using TensorDict makes it possible to vmap over the parameters in a more natural way (when working with model ensembles):

    params = torch.stack([TensorDict.from_module(module1), TensorDict.from_module(module2)], 0)
    y = torch.vmap(lambda p, x: torch.func.functional_call(module1, p, x), (0, None))(params, x)  # x: a regular input batch
    

    This has the advantage that the parameters are represented with the appropriate shape from the outside. Stacking them or unbinding them comes out-of-the-box and at no cost.
    At a low level, tensordict's recursive implementation avoids some overhead (handled by caching nested modules) when setting parameters in a module, because it assigns the parameters to the modules one at a time, following their structure much more naturally than a flat dict would. This also makes the implementation (arguably) more straightforward.

  • First-class dimensions: parts of a TensorDict batch-size can be hidden using first-class dimensions. For instance, the example above can be simplified as follows:

    from functorch import dim as ftdim
    params = torch.stack([TensorDict.from_module(module1), TensorDict.from_module(module2)], 0)  # params has batch-size [2]
    dim = ftdim.dims(1)
    params = params[dim]  # params now has an empty batch-size
    torch.func.functional_call(module1, params, x)
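Stacking two parameter tensordicts, as in the ensemble examples above, prepends a batch dimension to every leaf, which is what vmap then maps over with in_dims=0. The shape-level sketch below assumes nested dicts of shape tuples; `stack_shapes` is a hypothetical helper, not how torch.stack is implemented for TensorDict.

```python
from typing import Dict, List, Tuple, Union

Shape = Tuple[int, ...]
Tree = Dict[str, Union[Shape, "Tree"]]


def stack_shapes(trees: List[Tree]) -> Tree:
    """Shape-level sketch of stacking nested parameter dicts along a new dim 0."""
    first = trees[0]
    if any(tree.keys() != first.keys() for tree in trees):
        raise ValueError("all trees must have the same keys")
    out: Tree = {}
    for key, value in first.items():
        values = [tree[key] for tree in trees]
        if isinstance(value, dict):
            out[key] = stack_shapes(values)  # recurse into sub-modules
        elif any(v != value for v in values):
            raise ValueError(f"shape mismatch for {key!r}: {values}")
        else:
            # Each leaf gains a leading dim equal to the number of stacked trees.
            out[key] = (len(trees),) + tuple(value)
    return out


# Parameter shapes of two nn.Sequential(nn.Linear(3, 4)) ensemble members.
shapes1 = {"0": {"weight": (4, 3), "bias": (4,)}}
shapes2 = {"0": {"weight": (4, 3), "bias": (4,)}}
print(stack_shapes([shapes1, shapes2]))
# {'0': {'weight': (2, 4, 3), 'bias': (2, 4)}}
```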
    

PyTree compatibility

Additionally, TensorDict is compatible with pytree. This is aimed at making tensordict interchangeable with regular dicts, but pytree is not used to implement tensordict functions (as they manipulate the batch-size, which has some inherent logic).

When using pytree, the batch-size will be preserved:

from torch.utils._pytree import tree_map

td = TensorDict(
    {
        "a": TensorDict(
            {
                "b": TensorDict({"c": torch.ones(2, 3, 4)}, [2, 3]),
                "d": torch.ones(2),
            },
            [2],
        ),
        "e": 1,
    },
    [],
)
td = tree_map(lambda x: x + 1, td)
assert (td == 2).all()
assert td.shape == torch.Size([])
assert td["a"].shape == torch.Size([2])
assert td["a", "b"].shape == torch.Size([2, 3])
assert td["a", "b", "c"].shape == torch.Size([2, 3, 4])

Since the batch-size is preserved as-is, operations that change it (eg, TensorDict.flatten) cannot be implemented using pytree. Another issue with pytree is that it deconstructs and reconstructs the content of the tensordict, which can cause some unwanted overhead.
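The batch-size preservation described above can be pictured with a toy flatten/unflatten pair: leaves are extracted, keys and batch-sizes go into a spec, and the structure is rebuilt around the new leaves. This is a stdlib-only sketch with hypothetical names (`flatten_node`, `unflatten_node`, `map_leaves`), not the torch.utils._pytree API or TensorDict's actual registration. It also makes the overhead visible: every call tears down and rebuilds the whole structure.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical node layout: {"batch_size": tuple, "fields": {key: leaf or node}}.
Node = Dict[str, Any]


def flatten_node(node: Node) -> Tuple[List[Any], Any]:
    """Split a node into its leaves and a spec holding keys and batch-sizes."""
    leaves: List[Any] = []
    spec_fields: Dict[str, Any] = {}
    for key, value in node["fields"].items():
        if isinstance(value, dict):
            sub_leaves, spec_fields[key] = flatten_node(value)
            leaves += sub_leaves
        else:
            leaves.append(value)
            spec_fields[key] = None  # None marks a leaf position
    # The batch-size lives in the spec, not among the leaves.
    return leaves, (node["batch_size"], spec_fields)


def unflatten_node(leaves: List[Any], spec: Any) -> Node:
    """Rebuild a node, consuming leaves in the order flatten_node produced them."""
    it = iter(leaves)

    def build(sub_spec: Any) -> Node:
        batch_size, spec_fields = sub_spec
        return {
            "batch_size": batch_size,
            "fields": {k: next(it) if s is None else build(s) for k, s in spec_fields.items()},
        }

    return build(spec)


def map_leaves(fn: Callable[[Any], Any], node: Node) -> Node:
    # Tear down, transform every leaf, rebuild: the batch-sizes come back untouched.
    leaves, spec = flatten_node(node)
    return unflatten_node([fn(leaf) for leaf in leaves], spec)


td = {"batch_size": (), "fields": {
    "a": {"batch_size": (2,), "fields": {"d": 1}},
    "e": 1,
}}
out = map_leaves(lambda x: x + 1, td)
print(out["fields"]["e"], out["fields"]["a"]["batch_size"])  # 2 (2,)
```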

Nested-tensor compatibility

Per se, nested tensors can be stored within a tensordict without any apparent issue:

nested = torch.nested.nested_tensor([torch.zeros(3), torch.ones(4)])
td = TensorDict({"nested_tensor": nested}, [])  # an empty batch-size works perfectly

nested = torch.nested.nested_tensor([torch.zeros(3), torch.ones(4)])
td = TensorDict({"nested_tensor": nested}, batch_size=[2])  # also works, since the nested tensor has a leading dimension of size 2
print(td)

which prints

TensorDict(
    fields={
        nested_tensor: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Indexing a single item works fine:

td[0]  # returns a tensordict holding the first component, torch.zeros(3)

but using other index types (eg, slices) raises an exception.

@tensorclass

Users may prefer a more strongly-typed class. This can come in handy when handling datasets, as predicting which fields are present in the data structure is sometimes key to writing a transparent codebase.
To this aim, TensorDict comes with a specialised @dataclass decorator, @tensorclass.
@tensorclass works exactly like @dataclass, except that a batch-size must be passed to it during construction. Unlike TensorDict, a tensorclass can have dedicated methods, and it supports non-tensor data:

from typing import Any

import torch
from tensordict import tensorclass

@tensorclass
class MyData:
    images: torch.Tensor
    labels: torch.Tensor
    non_tensor: Any
    path_to_data: str

    def __post_init__(self):
        if self.path_to_data:
            self.memmap_(self.path_to_data)

    def get_data_with_label(self, label) -> "MyData":
        return self[self.labels == label]

images = torch.randint(255, (42, 3, 64, 64))
labels = torch.randint(100, (42,))  # one label per image

# creates a memory-mapped dataset
data = MyData(images, labels, batch_size=[42], non_tensor=None, path_to_data="/tmp")

When indexing a tensorclass instance with non-tensor data, the resulting non-tensor data will be identical to those in the parent object.
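To show this indexing behaviour without depending on tensordict, here is a plain @dataclass sketch in which lists stand in for tensors. `MiniData` is hypothetical: it mimics a required batch-size, a dedicated method, and non-tensor data being carried over unchanged on indexing, but it is not how @tensorclass is implemented.

```python
from dataclasses import dataclass
from typing import Any, List, Sequence


@dataclass
class MiniData:
    """Hypothetical list-based stand-in for a @tensorclass with a 1-D batch-size."""

    images: List[Any]
    labels: List[int]
    batch_size: int
    non_tensor: Any = None

    def __post_init__(self) -> None:
        # Every "tensor" field must share the leading batch dimension.
        for name in ("images", "labels"):
            if len(getattr(self, name)) != self.batch_size:
                raise ValueError(f"{name} does not match batch_size={self.batch_size}")

    def __getitem__(self, mask: Sequence[bool]) -> "MiniData":
        keep = list(mask)

        def select(values: List[Any]) -> List[Any]:
            return [v for v, k in zip(values, keep) if k]

        # Boolean-mask indexing applies to the batched fields only;
        # non-tensor data is copied over unchanged from the parent.
        return MiniData(select(self.images), select(self.labels),
                        batch_size=sum(keep), non_tensor=self.non_tensor)

    def get_data_with_label(self, label: int) -> "MiniData":
        return self[[lab == label for lab in self.labels]]


data = MiniData(["img0", "img1", "img2"], [1, 0, 1], batch_size=3, non_tensor="meta")
subset = data.get_data_with_label(1)
print(subset)
# MiniData(images=['img0', 'img2'], labels=[1, 1], batch_size=2, non_tensor='meta')
```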

A great thanks to @apbard and @tcbegley for the help with this!

Printing data content

A minor yet useful feature is simply seeing what a data structure looks like. Unlike the print of a torch.Tensor, a TensorDict representation displays only tensor metadata such as shape, device, dtype and shared-memory status.
When looking at a dataset, this helps to quickly understand its content. For parameter batches, this representation is also handier than printing a state_dict, which hardly anyone would do in their right mind.

Miscellaneous

TensorDict supports a bunch of other features:

  • locking: a tensordict can be locked, in which case any set operation that does not modify an existing tensor in place is prohibited. There are two ways to (un)lock a tensordict:
    data = data.lock_()
    with data.unlock_():
        data.set("key", val)
    

and conversely: both lock_() and unlock_() can be used as context managers. Using unlock_() in a with statement comes in handy when one knows that the lock should only be lifted temporarily.

  • caching results: in some cases, we know that the result of an operation will not change when called twice (for instance flatten_keys on a locked tensordict). In this case, the result is cached. All cached operations are very cheap in terms of memory requirements.
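The locking rule from the list above can be sketched with a small wrapper: replacing or adding an entry fails while locked, in-place writes do not, and a context manager lifts the lock temporarily. `LockableDict`, its `set`/`set_` split and the `unlocked()` context manager are hypothetical; in tensordict, unlock_() itself is used in the with statement, as in the snippet above.

```python
from contextlib import contextmanager
from typing import Any, Dict, Iterator


class LockableDict:
    """Hypothetical sketch of TensorDict-style locking over a dict of lists."""

    def __init__(self, data: Dict[str, Any]) -> None:
        self._data = dict(data)
        self.is_locked = False

    def lock_(self) -> "LockableDict":
        self.is_locked = True
        return self

    def set(self, key: str, value: Any) -> None:
        # Adding or replacing an entry is what the lock forbids.
        if self.is_locked:
            raise RuntimeError(f"cannot set {key!r}: the dict is locked")
        self._data[key] = value

    def set_(self, key: str, value: Any) -> None:
        # In-place write into the existing storage: allowed even when locked.
        self._data[key][:] = value

    @contextmanager
    def unlocked(self) -> Iterator["LockableDict"]:
        # Temporarily lift the lock and restore the previous state on exit.
        was_locked, self.is_locked = self.is_locked, False
        try:
            yield self
        finally:
            self.is_locked = was_locked

    def __getitem__(self, key: str) -> Any:
        return self._data[key]


d = LockableDict({"x": [0, 0]}).lock_()
d.set_("x", [1, 2])   # in-place: allowed while locked
with d.unlocked():
    d.set("y", [3])   # new entry: allowed only inside the context
print(d.is_locked, d["x"], d["y"])  # True [1, 2] [3]
```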

Utilisation

Data representation and manipulation

TensorDict is ideal for representing a dataset with fast indexing (provided that the data can be stored contiguously). Any sampling from a dataset boils down to a call to __getitems__:

# dummy training loop
for data in dataset:
    data = data.to(device, non_blocking=True)
    model(data)  # provided that the model reads a tensordict; see TensorDictModule in the tensordict library for an example
    loss = loss_fn(data) 
    loss.backward()
    optim.step()
    optim.zero_grad()

The point of this example is that we can abstract away the content of the data and focus on the training loop itself. For instance, a training loop for image segmentation or classification wouldn’t look very different in this context.

It could be argued that TensorDict obfuscates the inner mechanisms of the training algorithm, but we think it does the opposite: it focuses attention on the relevant part (the model construction). On top of this, printing the data during code execution is much more transparent than doing the same thing in a generic PyTorch script.

Dataloading

PyTorch dataloaders handle sampling multiple items from a dataset by either (1) spawning multiple parallel processes and executing the indexing independently using __getitem__ or (2) calling __getitems__ directly.
This makes TensorDict compelling: there is no need for a collate_fn, and if the data is contiguous, sampling is considerably faster than with multiprocessing (see figure above).
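The __getitems__ path mentioned above is what removes the need for a collate_fn: if each field is stored contiguously, a batch is one gather per field rather than many per-item samples to merge. The sketch below uses lists as a stand-in for contiguous tensors; `ContiguousDataset` is hypothetical and only shows the dataset side, not DataLoader configuration.

```python
from typing import Any, Dict, List, Sequence


class ContiguousDataset:
    """Hypothetical map-style dataset keeping each field in one list (a stand-in for a tensor)."""

    def __init__(self, **columns: List[Any]) -> None:
        if len({len(col) for col in columns.values()}) != 1:
            raise ValueError("all columns must have the same length")
        self.columns = columns

    def __len__(self) -> int:
        return len(next(iter(self.columns.values())))

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # One sample at a time: batching these would require a collate_fn.
        return {key: col[index] for key, col in self.columns.items()}

    def __getitems__(self, indices: Sequence[int]) -> Dict[str, List[Any]]:
        # One gather per field returns an already-batched sample.
        return {key: [col[i] for i in indices] for key, col in self.columns.items()}


ds = ContiguousDataset(x=[10, 11, 12, 13], y=[0, 1, 0, 1])
print(ds.__getitems__([3, 0]))  # {'x': [13, 10], 'y': [1, 0]}
```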

Data transforms and preproc

The TensorDict.map method makes TensorDict a compelling tool for preprocessing.

Similarly, to get a sense of what data transforms look like with tensordict, one can have a look at torchrl’s transforms, which all have a simple signature: transform(data), where data is a tensordict. With this, it is easy to imagine what a random crop acting on both an image and a mask would look like. In RL, we haven’t found (so far!) any use case of transforms over a set of tensors that could not be handled with TensorDict. Recall that this comes with the rest of the services (eg, easy casting to device, good representation, etc.), which makes this feature much easier to blend into a code base than a regular dictionary.
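As an illustration of the transform(data) signature, here is a random crop that cuts the same window out of both an image and its mask and leaves other entries alone. It is a stdlib-only sketch on nested lists; `random_crop` and the "image"/"mask" key names are assumptions, not torchrl's API.

```python
import random
from typing import Any, Dict, Optional, Sequence


def random_crop(data: Dict[str, Any], size: int, keys: Sequence[str] = ("image", "mask"),
                rng: Optional[random.Random] = None) -> Dict[str, Any]:
    """Crop the same random size x size window from every entry in `keys`.

    Entries not listed in `keys` (e.g. a label) are passed through untouched.
    """
    if rng is None:
        rng = random.Random()
    height, width = len(data[keys[0]]), len(data[keys[0]][0])
    top = rng.randint(0, height - size)   # randint is inclusive on both ends
    left = rng.randint(0, width - size)
    out = dict(data)
    for key in keys:
        out[key] = [row[left:left + size] for row in data[key][top:top + size]]
    return out


image = [[10 * r + c for c in range(4)] for r in range(4)]
mask = [[(r + c) % 2 for c in range(4)] for r in range(4)]
crop = random_crop({"image": image, "mask": mask, "label": 3}, size=2, rng=random.Random(0))
```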

Distributed features

TensorDict has a bunch of torch.distributed capabilities:

  • point-to-point communication is treated as a whole, reducing any query to a single call (even for highly nested structures like the parameters of a model). TensorDict.gather, TensorDict.isend, TensorDict.irecv and others mirror their torch.distributed counterparts.
  • In RPC settings, MemoryMappedTensor provides a 3-10x speedup on node-to-node communication with a transparent communication pipeline (provided that the nodes have a shared physical storage).
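One way to see how a whole nested structure can travel in a single call is to pack all leaves into one flat buffer plus metadata describing keys and lengths. The sketch below is an assumption for illustration (lists stand in for tensors, `pack`/`unpack` are hypothetical); we do not claim this is how TensorDict.isend or TensorDict.irecv are implemented.

```python
from typing import Any, Dict, List, Tuple

Meta = List[Tuple[Tuple[str, ...], int]]


def pack(data: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Tuple[List[float], Meta]:
    """Concatenate every leaf (a list of floats) into one flat buffer, plus (key, length) metadata."""
    buffer: List[float] = []
    meta: Meta = []
    for key, value in data.items():
        if isinstance(value, dict):
            sub_buffer, sub_meta = pack(value, prefix + (key,))
            buffer += sub_buffer
            meta += sub_meta
        else:
            buffer += value
            meta.append((prefix + (key,), len(value)))
    return buffer, meta


def unpack(buffer: List[float], meta: Meta) -> Dict[str, Any]:
    """Rebuild the nested structure by slicing the flat buffer with the metadata."""
    out: Dict[str, Any] = {}
    offset = 0
    for key, length in meta:
        node = out
        for part in key[:-1]:
            node = node.setdefault(part, {})
        node[key[-1]] = buffer[offset:offset + length]
        offset += length
    return out


params = {"0": {"weight": [1.0, 2.0], "bias": [0.5]}, "1": {"weight": [3.0]}}
buffer, meta = pack(params)  # one buffer, hence one send on the wire
assert unpack(buffer, meta) == params
```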

Parameter serialisation

Loading part of a model with torch.save / torch.load is cumbersome. It also relies on pickle, which is unsafe. TensorDict doesn’t need any of this, and only reads and writes what is needed. The data representation on disk mirrors the data structure in the code. On top of this, TensorDict is compatible with torchsnapshot and we’re working on improving the compatibility between the two libs!

Modules that use TensorDict-stored parameters

TensorDictParams is a dedicated TensorDictBase subclass to host parameters and buffers.
It comes in handy whenever a TensorDict must be stored in a module and its parameters must be accessed:

from tensordict.nn import TensorDictParams  # current location in the tensordict library

params = TensorDictParams(TensorDict.from_module(module))
id_module = nn.Identity()  # for the sake of simplicity
id_module.params = params
list(id_module.named_parameters())  # contains all the params from `module` under `params`, which behaves like an nn.Module

In some sense, TensorDictParams works like an nn.ParameterList, but it supports nesting and key-based features.

Remaining items to integrate

  • Extensive tests: WIP
  • Documentation

Open questions

torch.compile interaction

TensorDict compatibility with torch.compile is still an open question. In the library, we have dedicated classes (TensorDictModule and related), which go beyond the scope of this PR.
IMO it would already be great to support key-based and shape-based indexing: data.get("a_str"), data.get(("a_tuple",)), data.set("a_str", value), data.set(("a_tuple",), value), data[:3] etc.
This would already make it possible to build key-based models.
Taking a step back, I think the main question is how a TensorDict should be used within a model: if we use it as in the tensordict lib, a model reads data from a TensorDict instance, executes some operations, and writes all the results that are needed into the TensorDict before returning it. If the usage is more like TensorDict -> Tensor mappings (which is way more restrictive IMO), things would change a bit.

torchscriptability

A number of users have asked for torchscriptability. I’m not against it (it would actually help with adoption in the industry), but given the status of TorchScript and the amount of refactoring this would require, the opportunity cost has to be weighed carefully.

Overhead

We made some good progress on the overhead of common ops (get, set and such), but there is always room for improvement. I would guess that torch.compile support for the feature would already make this problem less relevant.

Tensor-like interaction

The interaction between TensorDict and tensor-like classes (broadly speaking) is not well defined. In the lib, we make sure that KeyedJaggedTensors from torchrec can be placed in a tensordict and key/shape-indexed accordingly. Having a clear API to register other tensor-like objects would be awesome, but it isn’t clear what that would look like in practice.

FSDP interaction

Intuitively, sharding a bunch of tensors with TensorDict seems “easy” to achieve (in the sense: if a common operation can be identified, it can be readily applied to a bunch of tensors). Similarly, all_gather and similar can be executed at every level of the model with little pain, and isolated over the relevant set of parameters every time they’re required.
This is all very theoretical and I’d need to get a better understanding of the mechanisms underpinning FSDP (plus, we may not want to reinvent the wheel: given the amount of work this feature requires, I’m keeping my excitement in check 🙂)
