I did a bit of experimentation on how torch makes use of strided memory layout to save memory, and I wondered about two things that I think are related. First the simple one:
- torch.tile could make use of a strided layout, but does not:
import torch
base = torch.tensor([[1, 2, 3]]) # [1, 3]
expanded = base.expand(100, 3) # [100, 3]
assert expanded.storage().size() == 3
assert expanded.stride() == (0, 1)
tiled = torch.tile(base, [100, 1]) # [100, 3]
assert (tiled == expanded).all()
assert tiled.storage().size() == 300
assert tiled.stride() == (3, 1)
I am aware that not every tiling operation can be represented in a strided layout, but it seems like it should not be too difficult to detect the cases where it can and handle those more efficiently (see the sketch below). Some developers might simply prefer tiling over expanding even though it results in the same tensor. This would also be in line with the view vs. reshape philosophy, where reshape falls back to the view behavior whenever that is possible.
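For the cases where every repeated dimension has size 1 (which covers my example above), the zero-copy version can already be written in terms of expand. A minimal sketch, using a hypothetical helper name tile_as_view (not an existing PyTorch API):
import torch
def tile_as_view(t, reps):
    # Hypothetical helper: zero-copy "tile", only valid when every dimension
    # that gets repeated (reps[d] > 1) has size 1 in the input tensor.
    assert len(reps) == t.dim()
    assert all(r == 1 or s == 1 for r, s in zip(reps, t.shape))
    return t.expand(*[s * r for s, r in zip(t.shape, reps)])
base = torch.tensor([[1, 2, 3]])                # [1, 3]
tiled = tile_as_view(base, [100, 1])            # [100, 3], no copy
assert (tiled == torch.tile(base, [100, 1])).all()
assert tiled.storage().size() == 3              # still only 3 stored elements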
- Now to my second, more general question, which is what led me to the one above:
It seems to me there is a missed opportunity to generalize the idea of a strided layout so that data copies are avoided in all cases of a tile operation. To make this more concrete, I want to motivate it by outlining an approach I had in mind.
I have an encoder network that encodes input samples into a latent representation, resulting in a tensor of shape [B, E], where B is the batch dimension and E is some fairly large latent dimension (e.g. the output of a CNN on an image). Now I want to use a query architecture where I have multiple queries per sample (e.g. N queries of dimension Q) that a smaller & faster network q should use to extract data from the encoding space. My query function takes [B', E] & [B', Q] as two input tensors and computes some [B', O] output, with B' = B x N. All my input queries come in the form [B, N, Q], which I can simply reshape to [B', Q]. For my encoding, however, I would like to do
encoding = encoding[:, None, :].expand(B, N, E)  # zero-copy: stride 0 along the new N dimension
encoding = encoding.view(B * N, E)               # fails with a RuntimeError
However, the view operation fails, as the data cannot be represented with that shape and the given strides without replicating it. With N and E possibly being quite large (e.g. N >> B), it would be nice to save on memory here.
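To make the memory cost concrete, here is a small stand-alone repro; the concrete values for B, N, E are just placeholders I picked for illustration:
import torch
B, N, E = 4, 1000, 512
encoding = torch.randn(B, E)
expanded = encoding[:, None, :].expand(B, N, E)   # zero-copy, stride 0 along N
# expanded.view(B * N, E) raises a RuntimeError here, because the flattened
# shape cannot be expressed over the original storage with any strides.
flat = expanded.reshape(B * N, E)                 # falls back to materializing a copy
assert flat.storage().size() == B * N * E         # B*N*E elements allocated
assert encoding.storage().size() == B * E         # vs. only B*E for the original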
I only did some quick brainstorming, but I do not immediately see a flaw in the following idea:
What if, next to each stride, we also had an “inverse” stride (I am sure I am not the first to think of this, so please tell me what the correct term for it is)?
The normal strides work, as I understand it, such that if I want to look up the value at tensor position [2, 9, 0] of a tensor with strides (50, 5, 1) and shape (14, 10, 5), the storage index is 2x50 + 9x5 + 0x1 = 145. I am now additionally thinking of “inverse” strides, which in a traditional layout would simply be the tensor shape, e.g. (14, 10, 5), and which would be applied as a modulo operator before multiplying with the stride, e.g. (2%14)x50 + (9%10)x5 + (0%5)x1 = 145, still the same storage index.

However, if for example my data were repeated twice along the first dimension (e.g. the 14 comes from an original size of 7 tiled 2 times), then I could represent this with inverse strides of (7, 10, 5), giving a storage index of (2%7)x50 + ... = 145, still the same, but now the storage index for element [2+7, 9, 0] would be the same as for [2, 9, 0]. This generalizes the stride of 0 that torch.expand produces (expand could now keep the stride at 1 instead of 0 and set the modulo “inverse” stride to 1, signaling more clearly what is going on).
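Spelled out as plain Python (this is just the indexing arithmetic of the idea, not anything that exists in PyTorch):
def storage_index(position, strides, inverse_strides):
    # index = sum over dimensions of (position[d] % inverse_strides[d]) * strides[d]
    return sum((p % inv) * s for p, s, inv in zip(position, strides, inverse_strides))
# Traditional layout: inverse strides equal the shape, so the modulo is a no-op.
assert storage_index((2, 9, 0), (50, 5, 1), (14, 10, 5)) == 145
# Data repeated twice along the first dimension (7 original rows -> 14):
assert storage_index((2, 9, 0), (50, 5, 1), (7, 10, 5)) == 145
assert storage_index((9, 9, 0), (50, 5, 1), (7, 10, 5)) == 145   # [2+7, 9, 0] aliases [2, 9, 0]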
And I believe this would allow every tiling operation to be expressed as a simple modification of strides and inverse strides, without any copy of the raw data (see the check below). Am I right, or am I missing something here?
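As a sanity check on that claim (under my proposed scheme, not anything PyTorch does today), here is torch.tile reproduced for one concrete case purely from the base storage, with the strides kept at (3, 1) and the inverse strides set to the original shape (2, 3):
import torch
base = torch.arange(6).reshape(2, 3)    # shape (2, 3), strides (3, 1), 6 stored elements
tiled = torch.tile(base, (3, 2))        # shape (6, 6); today this copies out 36 elements
flat = base.flatten()
for i in range(6):
    for j in range(6):
        idx = (i % 2) * 3 + (j % 3) * 1   # (position % inverse_stride) * stride, per dimension
        assert flat[idx] == tiled[i, j]   # every tiled value is reachable in the base storage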
I am actually not that deep into the field of memory layouts, so any literature, or reasons why PyTorch has not adopted something like this, would be greatly appreciated; I have difficulty finding the right resources on this “dry” topic on my own.