Collating named tensors

The named tensor functionality is very promising for the models I use: I work on inverse problems where tensors can have many channels, which need to be permuted quite often to fit the operators (such as the torch FFT). Having the dimensions named would make life much easier when you e.g. want to sum over a certain named dimension.
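
For context, a minimal sketch of what this buys you (shapes and dimension names here are purely illustrative):

import torch

x = torch.randn(4, 8, 64, 64, names=('batch', 'channel', 'height', 'width'))
y = x.sum('channel')                                   # reduce over a dimension by name
z = x.align_to('batch', 'height', 'width', 'channel')  # permute dimensions by name, no index bookkeeping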

However, the API is unstable and not yet complete, and I am looking for the proper / most elegant way to implement this. I took a rather straightforward approach: in my Dataset class I have transforms which name the dimensions, and when some function does not support names, I save the names, drop them, apply the operation, and add the names back again. My datasets typically output dictionaries with keys such as target, image and so on, whose values are therefore named tensors in my case.
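
The save/drop/restore dance looks roughly like this (apply_unnamed and the flip op are illustrative, not my actual transforms):

import torch

def apply_unnamed(tensor, op):
    # Save the names, run the op on an unnamed view, then restore the names.
    names = tensor.names
    out = op(tensor.rename(None))
    return out.refine_names(*names)

x = torch.randn(8, 64, 64, names=('channel', 'height', 'width'))
y = apply_unnamed(x, lambda t: torch.flip(t, dims=(2,)))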

All good, until you want to collate these tensors and add the name batch to the new first dimension. Unfortunately stack does not yet support named tensors, so I tried the following and rewrote the collate_fn:

if isinstance(elem, torch.Tensor):
    out = None
    # TODO: Named tensor: once stack supports named tensors, drop the rename.
    names = elem.names
    elem = elem.rename(None)
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum(x.rename(None).numel() for x in batch)
        storage = elem.storage()._new_shared(numel)
        out = elem.new(storage)
    out_batch = torch.stack([t.rename(None) for t in batch], 0, out=out)
    if any(names):
        new_names = tuple(['batch'] + list(names))
        out_batch = out_batch.refine_names(*new_names)
    return out_batch
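
This branch sits inside a copy of default_collate which you then pass to the loader; if that copy is called, say, named_collate (the name is illustrative), you would wire it up as:

loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=named_collate)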

Unfortunately, that does not work yet:
RuntimeError: pin_memory is not yet supported with named tensors. Please drop names via `tensor = tensor.rename(None)`, call the op with an unnamed tensor, and set names on the result of the operation.

The solution here would be to rewrite the DataLoader too and add this functionality there. But then I would be rewriting so much of the core functionality of PyTorch (I am using 1.5) that I am wondering: is this the intended usage of named tensors, or are they meant to be used in a different place? If my approach is fine, is there a more convenient way to do all this without having to rewrite everything (e.g. adding the names as a key in my dictionary)?

Hi Jonas,

sorry to hear you are having trouble with named tensors.
To the best of my knowledge, your approach should be correct; these methods are just not implemented yet.

Would the workaround of returning a dict from your Dataset and adding the names afterwards inside the DataLoader loop work?

Hi Patrick,

The collate function I propose above does work if you disable memory pinning, which is not a bottleneck for my application (assuming pinning support will be implemented in the future). However, you run into a different issue if you do this and have multiple workers:

RuntimeError: NYI: Named tensors don’t support serialization. Please drop names via tensor = tensor.rename(None) before serialization.

This is trickier to handle, as it is likely somewhere deep in the DataLoader. I will attempt the method of adding names through the dict and will report when and if that works!
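
For reference, the combination that does work today (at the cost of pinning and worker processes) would be something like this, with named_collate again being the modified collate from above:

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    pin_memory=False,         # pinning is not yet implemented for named tensors
    num_workers=0,            # worker processes serialize tensors, which is NYI too
    collate_fn=named_collate,
)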


This is a (hacky) solution that does work:

class DropNames:
    def __call__(self, sample):
        new_sample = {}

        for k, v in sample.items():
            if isinstance(v, torch.Tensor) and any(v.names):
                # Store the names as a single ';'-separated string, as the
                # default collate_fn will do funky things with a list of names.
                new_sample[k + '_names'] = ';'.join(v.names)
                v = v.rename(None)
            new_sample[k] = v

        return new_sample


class AddNames:
    def __init__(self, add_batch_dimension=True):
        self.add_batch_dimension = add_batch_dimension

    def __call__(self, sample):
        keys = [k[:-6] for k in sample.keys() if k.endswith('_names')]
        new_sample = {k: v for k, v in sample.items() if not k.endswith('_names')}

        for key in keys:
            if self.add_batch_dimension:
                # After collation the '_names' entry is a list with one string per
                # batch element; take the first and prepend the new batch dimension.
                names = ['batch'] + sample[key + '_names'][0].split(';')
            else:
                names = sample[key + '_names'].split(';')

            new_sample[key] = sample[key].rename(*names)

        return new_sample

The first function can be added at the end of all the transforms (which are typically composed). It drops the names and stores them as a single string separated by ';' (assuming your dimension names contain no ';'). I do this because the collate function can do weird things with a list of strings. If you then use the second function as:

add_names = AddNames(add_batch_dimension=True)
for iter_idx, data in enumerate(data_loader):
    data = add_names(data)

you get the names in the tensors back. You might need to drop them again later, as many operations do not yet support names, but in my case I make a lot of use of the FFT, which in the current implementation requires you to permute the axes often. Having named tensors really helps in reducing bugs.
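
End to end, the wiring then looks roughly like this (SomeNamedTransform and MyDictDataset are illustrative placeholders, not part of the code above):

from torchvision.transforms import Compose

# DropNames goes last, so the DataLoader only ever sees unnamed tensors.
transforms = Compose([SomeNamedTransform(), DropNames()])
dataset = MyDictDataset(transform=transforms)  # yields dicts of unnamed tensors plus '_names' strings
data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=4)

Because DropNames runs last, stacking, pinning and worker serialization all work, and the loop above restores the names on the collated batch.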

The above works with PyTorch 1.5, so perhaps it is of use to someone. Improvements are definitely welcome.
