Avoiding data copy when using array indexing

In the following toy example:

import torch
data = torch.randn(10, 1024)
indices = torch.tensor([0, 1, 7, 0])
selected_data = data[indices]

Is there a way to avoid a copy by having selected_data be a view of data and share the same underlying storage?

In my real case, as in this toy example, the values in indices may repeat and need not be consecutive. Also, assume that each element of data is large and that indices contains many repeats.

I’ve looked at alternatives such as index_select(), but they all seem to copy the data. Can this be avoided?

Bonus: my indices actually have a more specific structure. With k < n, they are [[0, 1, …, n], [k, k + 1, …, n + k], [2k, 2k + 1, …, n + 2k], …]. Does this help somehow? I assume it would for [0, 1, …, n] alone, but the complete structure is more exotic.
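For concreteness, the bonus index pattern can be generated with broadcasting. This is only a sketch: the values n = 4, k = 2 and a data length of 10 are illustrative assumptions, and the rows formula assumes only complete windows (rows that fit entirely inside the data) are wanted.

```python
import torch

# Illustrative values (assumptions): window parameters n, k and data length L.
n, k, L = 4, 2, 10
rows = (L - (n + 1)) // k + 1  # number of complete windows of length n + 1

# Row r holds r*k, r*k + 1, ..., r*k + n, built by broadcasting a column
# of row starts against a row of in-window offsets.
indices = torch.arange(rows).unsqueeze(1) * k + torch.arange(n + 1)
print(indices)
```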

Thanks!

Hi Yiftach!

No, not with general index values.

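PyTorch tensors require that their elements be laid out in storage in a regular
pattern: a starting offset plus a fixed stride for each dimension (you can “stride”
through the elements, but only with a constant step). An arbitrary list of indices
doesn’t fit that pattern, so when you index into a tensor with one, you have to
create a new contiguous tensor rather than an indexed view into the original tensor.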
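To make that reasoning concrete for the 1-D case: a view has one offset and one stride per dimension, so data[indices] can only be a view when the indices form an arithmetic progression with a non-negative step. [0, 1, 7, 0] has steps 1, 6 and -7, so it cannot. The helper view_for_1d_indices below is a hypothetical name, and the sketch assumes non-negative, in-range index values.

```python
import torch

def view_for_1d_indices(data, indices):
    """Hypothetical helper: return a view of 1-D `data` equal to data[indices]
    when a single offset and stride can express it, otherwise None.
    Assumes `indices` holds non-negative, in-range values."""
    idx = indices.tolist()
    if not idx:
        return None
    step = idx[1] - idx[0] if len(idx) > 1 else 0
    # Every consecutive difference must equal the same non-negative step.
    if step < 0 or any(b - a != step for a, b in zip(idx, idx[1:])):
        return None
    s = data.stride(0)  # respect data's own stride and offset
    return data.as_strided((len(idx),), (step * s,), data.storage_offset() + idx[0] * s)

data = torch.arange(10)
print(view_for_1d_indices(data, torch.tensor([1, 3, 5])))     # tensor([1, 3, 5]), a view
print(view_for_1d_indices(data, torch.tensor([0, 1, 7, 0])))  # None: steps are not constant
```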

You can use .untyped_storage().data_ptr() to verify that new memory
is being used.
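A minimal sketch of that check applied to the options from the question. same_storage is a hypothetical helper name, and untyped_storage() assumes PyTorch 2.0 or later (the version used in this thread).

```python
import torch

def same_storage(a, b):
    # Hypothetical helper: True when both tensors sit on the same underlying storage.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

data = torch.randn(10, 1024)
indices = torch.tensor([0, 1, 7, 0])

print(same_storage(data, data[indices]))                  # False: advanced indexing copies
print(same_storage(data, data.index_select(0, indices)))  # False: index_select copies
print(same_storage(data, data[2:5]))                      # True: basic slicing is a view
```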

For your specific bonus use case, however, the indices have a structure that makes
it possible to create a strided view into your original tensor instead of using
explicit indexing.

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> data = torch.tensor (list (range (10)))
>>> indices = torch.tensor ([0, 1, 7, 0])
>>> selected_data = data[indices]
>>>
>>> data.untyped_storage().data_ptr()                  # original storage for data
1811627391296
>>> selected_data.untyped_storage().data_ptr()         # new storage for indexed data
1811621701056
>>>
>>> n = 4
>>> k = 2
>>>
>>> indices = torch.tensor ([
...     [0, 1, 2, 3, 4],
...     [2, 3, 4, 5, 6],
...     [4, 5, 6, 7, 8]
... ])
>>>
>>> rows = (10 - (n + 1)) // k + 1
>>> cols = n + 1
>>>
>>> dataA = data[indices]
>>> dataB = data.as_strided ((rows, cols), (k, 1))
>>>
>>> data.untyped_storage().data_ptr()                  # original storage for data
1811627391296
>>> dataA.untyped_storage().data_ptr()                 # index -- new storage
1811578654848
>>> dataB.untyped_storage().data_ptr()                 # as_strided() -- view into original storage
1811627391296
>>>
>>> data[4] = 99                                       # modify original data
>>> dataA                                              # doesn't change new tensor
tensor([[0, 1, 2, 3, 4],
        [2, 3, 4, 5, 6],
        [4, 5, 6, 7, 8]])
>>> dataB                                              # change is reflected in view of original tensor
tensor([[ 0,  1,  2,  3, 99],
        [ 2,  3, 99,  5,  6],
        [99,  5,  6,  7,  8]])
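The as_strided() call above can be wrapped in a small function. This is a sketch under the assumptions that data is 1-D, has at least n + 1 elements, and that only complete windows are wanted; window_view_1d is a hypothetical name. Element [r, c] of the view reads data[r * k + c], which is exactly the value indices[r][c] selects. Using data.stride(0) and data.storage_offset() keeps it correct even when data is itself a strided view.

```python
import torch

def window_view_1d(data, n, k):
    """Hypothetical helper: rows of length n + 1 starting every k elements,
    returned as a view of 1-D `data`. Element [r, c] reads data[r * k + c]."""
    rows = (data.numel() - (n + 1)) // k + 1
    s = data.stride(0)  # honour data's own stride instead of assuming 1
    return data.as_strided((rows, n + 1), (k * s, s), data.storage_offset())

data = torch.arange(10)
n, k = 4, 2
rows = (data.numel() - (n + 1)) // k + 1
indices = torch.arange(rows).unsqueeze(1) * k + torch.arange(n + 1)

view = window_view_1d(data, n, k)
print(torch.equal(view, data[indices]))  # True: same values, but no copy
```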

Note that you have to be cautious when using such a view. Quoting from the
as_strided() documentation:

Warning

Prefer using other view functions, like torch.Tensor.expand(), to setting a view’s strides manually with as_strided, as this function’s behavior depends on the implementation of a tensor’s storage. The constructed view of the storage must only refer to elements within the storage or a runtime error will be thrown, and if the view is “overlapped” (with multiple indices referring to the same element in memory) its behavior is undefined.

In the example you give, your view is overlapped. My guess is that if you
treat dataB as read-only, you will be okay, but if you assign into dataB, you
may get non-deterministic results for the values of the overlapped elements.
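One of the “other view functions” that appears to fit the bonus case is Tensor.unfold(dimension, size, step), which returns a view containing every window of length size taken step elements apart along dimension. A minimal sketch, assuming the same n = 4, k = 2 as in the example above. The result is still an overlapped view, so the read-only caution applies equally; calling .clone() on it would give an ordinary writable copy if one is ever needed.

```python
import torch

data = torch.arange(10)
n, k = 4, 2

# A view of shape (3, 5) with strides (2, 1), the same layout as dataB above.
windows = data.unfold(0, n + 1, k)

print(windows)
print(windows.stride())                       # (2, 1)
print(windows.data_ptr() == data.data_ptr())  # True: no copy was made
```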

Best.

K. Frank


That perfectly answers all of my concerns. Thanks!

As a follow-up to the bonus case: what if my data is n-dimensional rather than one-dimensional? I’m working with videos, each of shape (channels, length, width, height), and normally I would write video[:, indices]. Playing with as_strided() doesn’t give me what I want, though. I’m modifying the size argument to add C before and W, H after, but I’m not sure what to put in the stride argument:

import torch
data = torch.arange(10).repeat(2, 4, 4, 1).permute(0, 3, 1, 2)
indices = torch.tensor ([
     [0, 1, 2, 3, 4],
     [2, 3, 4, 5, 6],
     [4, 5, 6, 7, 8]
 ])
dataA = data[:, indices]

n = 4
k = 2
rows = (10 - (n + 1)) // k + 1
cols = n + 1
dataB = data.as_strided ((2, rows, cols, 4, 4), (1, k, 1, 1, 1))

dataB - dataA # not an array of zeros

So I wanted to confirm that sharing the view is still possible in this case. I would assume it is, but I haven’t gotten it to work yet.

Hi Yiftach!

This depends on the details of where your data comes from, that is, on how it is laid out in memory.

With .as_strided(), the strides you pass are interpreted relative to the underlying
storage of the tensor you call it on, not relative to the tensor itself (that is,
how the tensor displays), and that tensor may itself have a non-trivial stride.

(I am using “trivial” or “default” stride to mean what pytorch calls a
“contiguous” tensor.)
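For reference, the “default” strides of a contiguous tensor follow directly from its shape: the last dimension has stride 1 and each earlier dimension steps over all the elements after it. A sketch, assuming no zero-sized dimensions; default_strides is a hypothetical name.

```python
import torch

def default_strides(shape):
    """Hypothetical helper: contiguous ("default") strides, in elements,
    for a shape with no zero-sized dimensions."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size  # the next dimension out skips over this whole block
    return tuple(reversed(strides))

shape = (2, 10, 4, 4)
print(default_strides(shape))       # (160, 16, 4, 1)
print(torch.empty(shape).stride())  # (160, 16, 4, 1)
```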

In any event, the stride you pass to .as_strided() isn’t correct, and the correct
stride depends on the stride that data itself has.

In the example you posted, data, being the result of a call to .permute(),
does have a non-trivial stride, so that affects the stride you need to use
to get your desired result.
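The stride arithmetic can be written in terms of whatever strides data already has. Assuming a (C, L, W, H) video with strides (s0, s1, s2, s3), the windowed view of shape (C, rows, n + 1, W, H) needs strides (s0, k * s1, s1, s2, s3) with the same storage offset. This is a sketch; video_window_view is a hypothetical name, and only complete windows are kept.

```python
import torch

def video_window_view(video, n, k):
    """Hypothetical helper: a view equal to video[:, indices] for the bonus
    index pattern, for a (C, L, W, H) video with arbitrary strides."""
    C, L, W, H = video.shape
    s0, s1, s2, s3 = video.stride()
    rows = (L - (n + 1)) // k + 1
    # Window r starts k frames after window r - 1; frames inside a window are adjacent.
    return video.as_strided((C, rows, n + 1, W, H),
                            (s0, k * s1, s1, s2, s3),
                            video.storage_offset())

data = torch.arange(10).repeat(2, 4, 4, 1).permute(0, 3, 1, 2)  # strides (160, 1, 40, 10)
n, k = 4, 2
rows = (data.shape[1] - (n + 1)) // k + 1
indices = torch.arange(rows).unsqueeze(1) * k + torch.arange(n + 1)

view = video_window_view(data, n, k)
print(view.stride())                        # (160, 2, 1, 40, 10)
print(torch.equal(view, data[:, indices]))  # True
```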

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> data = torch.arange(10).repeat(2, 4, 4, 1).permute(0, 3, 1, 2)
>>> indices = torch.tensor ([
...      [0, 1, 2, 3, 4],
...      [2, 3, 4, 5, 6],
...      [4, 5, 6, 7, 8]
...  ])
>>>
>>> data._is_view()                # data is a view (into the contiguous tensor returned by .repeat())
True
>>> data.is_contiguous()           # and so is not contiguous
False
>>> data.stride()                  # and has "non-default" strides
(160, 1, 40, 10)
>>>
>>> dataA = data[:, indices]       # (the new tensor returned by indexing is contiguous)
>>>
>>> rows = 3
>>> cols = 5
>>> dataB = data.as_strided ((2, rows, cols, 4, 4), (1, 2, 1, 1, 1))           # the strides are wrong
>>>
>>> torch.equal (dataB, dataA)     # so not equal
False
>>>
>>> dataCont = data.contiguous()   # make a contiguous version of data
>>>
>>> dataCont.is_contiguous()       # indeed contiguous
True
>>> dataCont.stride()              # with "default" strides
(160, 16, 4, 1)
>>>
>>> dataC = dataCont.as_strided ((2, rows, cols, 4, 4), (160, 32, 16, 4, 1))   # strides are straightforward
>>>
>>> torch.equal (dataC, dataA)     # works
True
>>>
>>> dataD = data.as_strided ((2, rows, cols, 4, 4), (160, 2, 1, 40, 10))       # strides are a little trickier
>>>
>>> torch.equal (dataD, dataA)     # also works
True
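Tensor.unfold() also seems to cover the video case without computing strides by hand, since it works from the tensor’s existing strides. A minimal sketch using the example data above: unfold appends the window dimension last, so a permute() (also a view) moves it next to the row dimension.

```python
import torch

data = torch.arange(10).repeat(2, 4, 4, 1).permute(0, 3, 1, 2)  # (C, L, W, H)
n, k = 4, 2
rows = (data.shape[1] - (n + 1)) // k + 1
indices = torch.arange(rows).unsqueeze(1) * k + torch.arange(n + 1)

# unfold(1, n + 1, k) gives shape (C, rows, W, H, n + 1); the permute
# reorders it to (C, rows, n + 1, W, H) to match data[:, indices].
windows = data.unfold(1, n + 1, k).permute(0, 1, 4, 2, 3)

print(windows.shape)                           # torch.Size([2, 3, 5, 4, 4])
print(torch.equal(windows, data[:, indices]))  # True
print(windows.data_ptr() == data.data_ptr())   # True: still a view
```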

Best.

K. Frank