The named tensor functionality is very promising for the models I use: I work on inverse problems where tensors can have many channels, which need to be permuted quite often to fit the operators (such as `torch.fft`). Having the dimensions named would make life much easier if you, e.g., want to sum over a certain named dimension.
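For context, the kind of convenience I mean looks roughly like this (a minimal sketch; the dimension names are just illustrative):

```python
import torch

# A stack of complex-valued images: name the dimensions once...
x = torch.randn(3, 2, 64, 64, names=('channel', 'complex', 'height', 'width'))

# ...and then reduce by name instead of by positional index.
energy = x.sum('channel')  # drops the 'channel' dimension
print(energy.names)        # ('complex', 'height', 'width')

# align_to permutes dimensions by name, e.g. to match an operator's layout.
y = x.align_to('height', 'width', 'channel', 'complex')
print(tuple(y.shape))      # (64, 64, 3, 2)
```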
However, the API is unstable and not yet complete, and I am trying to figure out the proper / most elegant way to implement this. I took a rather straightforward approach: in my Dataset class I have transforms which name the dimensions, and when some functions do not support the names, I save them, drop them, apply the operation, and add the names again. My datasets typically output dictionaries with keys such as `target`, `image` and so on, which in my case are therefore named tensors.
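The save/drop/restore pattern I describe can be sketched as follows (my own helper, not part of PyTorch; it only works for ops that keep the number and order of dimensions):

```python
import torch

def apply_unnamed(tensor, op):
    # Save the names, drop them, apply an op that does not support named
    # tensors, then put the names back. Only valid when `op` preserves the
    # number and order of dimensions.
    names = tensor.names
    result = op(tensor.rename(None))
    return result.refine_names(*names)

x = torch.randn(2, 8, 8, names=('coil', 'height', 'width'))
y = apply_unnamed(x, lambda t: t * 2)  # elementwise op, shape-preserving
print(y.names)  # ('coil', 'height', 'width')
```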
All good. However, a problem arises when you want to collate these tensors and name the new first (batch) dimension `batch`. Unfortunately `stack` does not yet support named tensors, so I tried the following and rewrote the `collate_fn`:
```python
if isinstance(elem, torch.Tensor):
    out = None
    # TODO: Named tensor: once stack supports named tensors, drop the rename.
    names = elem.names
    elem = elem.rename(None)
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum([x.rename(None).numel() for x in batch])
        storage = elem.storage()._new_shared(numel)
        out = elem.new(storage)
    out_batch = torch.stack([_.rename(None) for _ in batch], 0, out=out)
    if any(names):
        new_names = tuple(['batch'] + list(names))
        out_batch = out_batch.refine_names(*new_names)
    return out_batch
```
Unfortunately, that does not work yet:

```
RuntimeError: pin_memory is not yet supported with named tensors. Please drop names via `tensor = tensor.rename(None)`, call the op with an unnamed tensor, and set names on the result of the operation.
```
The solution here would be to rewrite the DataLoader too and add this functionality there. But then I would be rewriting so much of the core functionality of PyTorch (I am using 1.5) that I am wondering: is this the intended usage of named tensors, or are they meant to be used in a different place? If my approach is fine, is there a more convenient way to do all this without having to rewrite everything (e.g. adding the names as a key in my dictionary)?
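The "names as a dictionary key" alternative could look like this: keep tensors unnamed through the whole DataLoader pipeline (so `stack` and `pin_memory` never see names) and restore names in the training loop. This is my own sketch, not an official recipe; `named_collate`, `restore_names` and the `'_names'` key are names I made up:

```python
import torch
from torch.utils.data.dataloader import default_collate

def named_collate(batch):
    # Collate unnamed tensors and ship the dimension names alongside them,
    # so stack and pin_memory only ever see unnamed tensors.
    names = {k: v.names for k, v in batch[0].items()
             if isinstance(v, torch.Tensor)}
    stripped = [{k: (v.rename(None) if isinstance(v, torch.Tensor) else v)
                 for k, v in sample.items()} for sample in batch]
    collated = default_collate(stripped)
    collated['_names'] = names  # plain metadata; pin_memory leaves it alone
    return collated

def restore_names(collated):
    # Call this in the training loop, after the DataLoader is done.
    names = collated.pop('_names')
    for k, dim_names in names.items():
        if any(dim_names):
            collated[k] = collated[k].refine_names('batch', *dim_names)
    return collated

samples = [{'image': torch.randn(2, 4, 4, names=('channel', 'height', 'width'))}
           for _ in range(3)]
batch = restore_names(named_collate(samples))
print(batch['image'].names)  # ('batch', 'channel', 'height', 'width')
```

This keeps the stock DataLoader (including `pin_memory=True`) untouched, at the cost of one extra line in the training loop.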