Restructuring tensor data into overlapping windows using as_strided

I’m trying to transform a tensor from this structure:

[[[
      [1,2,3,4,5,6,7,8,9,10,11,12],
      [99,98,97,96,95,94,93,92,91,90,98,97]
]]]

(shape 1 x 1 x 2 x 12)

to this structure:

[[[
      [[1,2,3,4,5],
      [3,4,5,6,7],
      [5,6,7,8,9],
      [7,8,9,10,11]],
      [[99,98,97,96,95],
      [97,96,95,94,93],
      [95,94,93,92,91],
      [93,92,91,90,98]]
]]]

(shape 1 x 1 x 2 x 4 x 5)

I’ve been trying to use torch.as_strided() for this, but the way I’m calling it doesn’t give me the output I want. Here is the code I’m using to figure this out:

import torch
from torch import tensor

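stride=2          # step between the starts of consecutive windows
kernel_size=5     # length of each window
output_length=4   # number of windows per channel
num_channels=2    # number of rows (channels) in the input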

input=[[[
      [1,2,3,4,5,6,7,8,9,10,11,12],
      [99,98,97,96,95,94,93,92,91,90,98,97]
    ]]]

t = tensor(input)

print(t.shape)

output = torch.as_strided(t,(1,1,num_channels,output_length,kernel_size),(1,1,1,2,1))

desired_output = tensor([[[
      [[1,2,3,4,5],
      [3,4,5,6,7],
      [5,6,7,8,9],
      [7,8,9,10,11]],
      [[99,98,97,96,95],
      [97,96,95,94,93],
      [95,94,93,92,91],
      [93,92,91,90,98]]
    ]]])

print(output.shape)
print(desired_output.shape)
print(output)

My output looks like this:

[[[[[ 1,  2,  3,  4,  5],
    [ 3,  4,  5,  6,  7],
    [ 5,  6,  7,  8,  9],
    [ 7,  8,  9, 10, 11]],
   [[ 2,  3,  4,  5,  6],
    [ 4,  5,  6,  7,  8],
    [ 6,  7,  8,  9, 10],
    [ 8,  9, 10, 11, 12]]]]]

So the shape is correct, but the values are not. Can anyone offer hints on how to get the output I’m looking for?
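A likely explanation for the mismatch, sketched below under the assumption that t is the contiguous int64 tensor built above: as_strided indexes the underlying flat storage directly, so every stride has to be given in elements of that storage. For this input t.stride() is (24, 24, 12, 1), meaning consecutive channels are 12 elements apart. With a channel stride of 1, the second channel's first window starts at flat index 1 (the value 2) instead of flat index 12 (the value 99), which matches the [2, 3, 4, 5, 6] row in the output. The helper flat_index is hypothetical, written only to make that arithmetic explicit.

```python
import torch

t = torch.tensor([[[
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    [99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 98, 97],
]]])

# Strides of the contiguous (1, 1, 2, 12) input, in elements of flat storage.
print(t.stride())  # (24, 24, 12, 1)


def flat_index(index, strides):
    """Hypothetical helper: storage offset that as_strided reads for `index`."""
    return sum(i * s for i, s in zip(index, strides))


flat = t.flatten()

# Channel 1, window 0, element 0 using the strides from the question.
bad = flat_index((0, 0, 1, 0, 0), (1, 1, 1, 2, 1))
# Same element when the channel stride equals the row length (12).
good = flat_index((0, 0, 1, 0, 0), (24, 24, 12, 2, 1))

print(bad, flat[bad].item())    # 1 2
print(good, flat[good].item())  # 12 99
```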

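Tensor.unfold should work here. unfold(dimension, size, step) returns every slice of length size along the given dimension, starting a new slice every step elements, and appends a new last dimension of length size that holds each slice: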

out = t.unfold(dimension=3, size=5, step=2)
print(out)
# tensor([[[[[ 1,  2,  3,  4,  5],
#            [ 3,  4,  5,  6,  7],
#            [ 5,  6,  7,  8,  9],
#            [ 7,  8,  9, 10, 11]],

#           [[99, 98, 97, 96, 95],
#            [97, 96, 95, 94, 93],
#            [95, 94, 93, 92, 91],
#            [93, 92, 91, 90, 98]]]]])
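For completeness, the same view can probably also be built with torch.as_strided by keeping the input's own strides for the leading dimensions and splitting the last one into two strides: step * 1 between window starts and 1 between elements inside a window. This is a sketch that assumes a contiguous input with storage offset 0; unfold remains the simpler choice because it computes these strides for you. The number of windows follows from (length - size) // step + 1 = (12 - 5) // 2 + 1 = 4.

```python
import torch

t = torch.tensor([[[
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    [99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 98, 97],
]]])

size, step = 5, 2
length = t.shape[-1]
num_windows = (length - size) // step + 1  # (12 - 5) // 2 + 1 = 4

# Keep the strides of the leading dims; the last dim becomes
# (window start, position in window) with strides (step * s, s).
*lead_strides, last_stride = t.stride()
out_strided = torch.as_strided(
    t,
    (*t.shape[:-1], num_windows, size),
    (*lead_strides, step * last_stride, last_stride),
)

out_unfold = t.unfold(dimension=3, size=size, step=step)

print(out_strided.shape)                     # torch.Size([1, 1, 2, 4, 5])
print(torch.equal(out_strided, out_unfold))  # True
print(out_strided[0, 0, 1, 3].tolist())      # [93, 92, 91, 90, 98]
```

Here the computed strides are (24, 24, 12, 2, 1); the only real difference from the call in the question is the channel stride of 12 instead of 1.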