Hi Yiftach!

This will depend on the details of where your data comes from.

With `.as_strided()`, the `stride` you use is relative to the underlying storage of the tensor you call it on, rather than relative to the tensor itself (that is, how the tensor displays), which, itself, may have a non-trivial `stride`.
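
To illustrate this point (a minimal sketch, not from your example): on a transposed view, passing the view's own storage-level strides to `.as_strided()` reproduces the view, while the "trivial" strides for that shape read the storage row-major and give a different tensor.

```
import torch

base = torch.arange(6).reshape(2, 3)   # contiguous, strides (3, 1)
t = base.t()                           # a view with strides (1, 3)

# strides passed to .as_strided() count elements of base's storage
s = t.as_strided((3, 2), (1, 3))       # t's own storage-level strides
print(torch.equal(s, t))               # True -- reproduces t

s2 = t.as_strided((3, 2), (2, 1))      # "trivial" strides for shape (3, 2)
print(torch.equal(s2, t))              # False -- reads storage row-major
```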

(I am using “trivial” or “default” `stride` to mean what pytorch calls a “contiguous” tensor.)

In any event, the `stride` you pass to `.as_strided()` isn’t correct. However, the correct `stride` will depend on what `stride` `data` has.

In the example you posted, `data`, being the result of a call to `.permute()`, does have a non-trivial stride, so that affects the `stride` you need to use to get your desired result.

Consider:

```
>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> data = torch.arange(10).repeat(2, 4, 4, 1).permute(0, 3, 1, 2)
>>> indices = torch.tensor ([
... [0, 1, 2, 3, 4],
... [2, 3, 4, 5, 6],
... [4, 5, 6, 7, 8]
... ])
>>>
>>> data._is_view() # data is a view (into the contiguous tensor returned by .repeat())
True
>>> data.is_contiguous() # and so is not contiguous
False
>>> data.stride() # and has "non-default" strides
(160, 1, 40, 10)
>>>
>>> dataA = data[:, indices] # (the new tensor returned by indexing is contiguous)
>>>
>>> rows = 3
>>> cols = 5
>>> dataB = data.as_strided ((2, rows, cols, 4, 4), (1, 2, 1, 1, 1)) # the strides are wrong
>>>
>>> torch.equal (dataB, dataA) # so not equal
False
>>>
>>> dataCont = data.contiguous() # make a contiguous version of data
>>>
>>> dataCont.is_contiguous() # indeed contiguous
True
>>> dataCont.stride() # with "default" strides
(160, 16, 4, 1)
>>>
>>> dataC = dataCont.as_strided ((2, rows, cols, 4, 4), (160, 32, 16, 4, 1)) # strides are straightforward
>>>
>>> torch.equal (dataC, dataA) # works
True
>>>
>>> dataD = data.as_strided ((2, rows, cols, 4, 4), (160, 2, 1, 40, 10)) # strides are a little trickier
>>>
>>> torch.equal (dataD, dataA) # also works
True
```
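
For what it’s worth, here is one way to see where the trickier strides for `dataD` come from (a sketch using the same `data` and `indices` as above): the indexed dimension of `data` (the size-10 dimension) has storage stride 1, the rows of `indices` start 2 apart, and the entries within a row step by 1, so the new `rows` and `cols` dimensions advance the storage by 2 and 1 elements, respectively.

```
import torch

data = torch.arange(10).repeat(2, 4, 4, 1).permute(0, 3, 1, 2)
indices = torch.tensor ([
    [0, 1, 2, 3, 4],
    [2, 3, 4, 5, 6],
    [4, 5, 6, 7, 8]
])

d0, d1, d2, d3 = data.stride()          # (160, 1, 40, 10)
row_step, col_step = 2, 1               # how indices advances along dim 1
strides = (d0, row_step * d1, col_step * d1, d2, d3)
print(strides)                          # (160, 2, 1, 40, 10)

dataD = data.as_strided ((2, 3, 5, 4, 4), strides)
print(torch.equal (dataD, data[:, indices]))  # True
```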

Best.

K. Frank