I’m trying to transform a tensor from this structure:
[[[
[1,2,3,4,5,6,7,8,9,10,11,12],
[99,98,97,96,95,94,93,92,91,90,98,97]
]]]
(shape 1 x 1 x 2 x 12)
to this structure:
[[[
[[1,2,3,4,5],
[3,4,5,6,7],
[5,6,7,8,9],
[7,8,9,10,11]],
[[99,98,97,96,95],
[97,96,95,94,93],
[95,94,93,92,91],
[93,92,91,90,98]]
]]]
(shape 1 x 1 x 2 x 4 x 5)
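Written as an index mapping, the target takes windows of kernel_size = 5 elements that start every stride = 2 elements along the last dimension, separately for each channel c: out[0, 0, c, i, k] = x[0, 0, c, 2*i + k] for i in 0..3 and k in 0..4. A plain-slicing reference version is sketched below to pin down that mapping; `sliding_windows_reference` is a hypothetical helper name, not a PyTorch function, and it copies data instead of creating a view.

```python
import torch

def sliding_windows_reference(x: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Hypothetical reference helper: slide a window of `kernel_size` elements
    along the last dimension of `x`, advancing `stride` elements per row."""
    length = x.shape[-1]
    output_length = (length - kernel_size) // stride + 1
    # Slice each window explicitly so the mapping is easy to read:
    # out[..., i, k] = x[..., i * stride + k]
    rows = [x[..., i * stride : i * stride + kernel_size] for i in range(output_length)]
    # Stack the windows along a new second-to-last dimension.
    return torch.stack(rows, dim=-2)

x = torch.tensor([[[
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    [99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 98, 97],
]]])
ref = sliding_windows_reference(x, kernel_size=5, stride=2)
print(ref.shape)  # torch.Size([1, 1, 2, 4, 5])
```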
I’ve been trying to use torch.as_strided() for this, but the way I am using it does not give me the desired output. Here is the code I am using to figure this out:
import torch
from torch import tensor
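# sliding-window parameters
stride = 2
kernel_size = 5
output_length = 4
num_channels = 2
data = [[[
[1,2,3,4,5,6,7,8,9,10,11,12],
[99,98,97,96,95,94,93,92,91,90,98,97]
]]]
t = tensor(data)
print(t.shape)
# attempt: view t as overlapping windows of shape (1, 1, num_channels, output_length, kernel_size)
output = torch.as_strided(t, (1, 1, num_channels, output_length, kernel_size), (1, 1, 1, 2, 1))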
desired_output = tensor([[[
[[1,2,3,4,5],
[3,4,5,6,7],
[5,6,7,8,9],
[7,8,9,10,11]],
[[99,98,97,96,95],
[97,96,95,94,93],
[95,94,93,92,91],
[93,92,91,90,98]]
]]])
print(output.shape)
print(desired_output.shape)
print(output)
My output looks like this:
[[[[[ 1, 2, 3, 4, 5],
[ 3, 4, 5, 6, 7],
[ 5, 6, 7, 8, 9],
[ 7, 8, 9, 10, 11]],
[[ 2, 3, 4, 5, 6],
[ 4, 5, 6, 7, 8],
[ 6, 7, 8, 9, 10],
[ 8, 9, 10, 11, 12]]
]]]
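So the shape is correct, but the numbers are not: the second block of windows is built from the first row (2, 3, 4, ...) instead of from the second row (99, 98, 97, ...). Can anyone offer any hints on how I can get the output I am looking for?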
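A sketch of one likely cause and fix: the strides passed to torch.as_strided are counted in elements of the underlying storage. For the contiguous (1, 1, 2, 12) tensor here, t.stride() is (24, 24, 12, 1), so moving to the next channel means skipping 12 elements, not 1. With a channel stride of 1, the windows of channel 1 start at the value 2, which is consistent with the second block of the output above. The code derives the strides from t.stride() instead of hardcoding them, and cross-checks the result against Tensor.unfold. The helper `storage_index` is hypothetical and only makes the offset arithmetic visible; the flat-index check assumes a contiguous tensor with storage offset 0.

```python
import torch

stride = 2
kernel_size = 5

t = torch.tensor([[[
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    [99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 98, 97],
]]])
output_length = (t.shape[-1] - kernel_size) // stride + 1  # 4 windows

def storage_index(index, strides):
    """Hypothetical helper: storage position of `index` under as_strided
    (assumes storage_offset is 0, as it is for `t` here)."""
    return sum(i * s for i, s in zip(index, strides))

flat = t.flatten()  # contiguous, so flat positions equal storage positions
channel_1_start = (0, 0, 1, 0, 0)

# The strides from the question: one channel step moves just 1 element.
print(flat[storage_index(channel_1_start, (1, 1, 1, 2, 1))].item())  # 2

# Keep the tensor's own strides for the leading dims, then add a
# window-row step of `stride` elements and a 1-element step inside a window.
*lead_strides, last_stride = t.stride()
new_strides = (*lead_strides, stride * last_stride, last_stride)
print(new_strides)  # (24, 24, 12, 2, 1)
print(flat[storage_index(channel_1_start, new_strides)].item())  # 99

windows = torch.as_strided(
    t,
    (*t.shape[:-1], output_length, kernel_size),
    new_strides,
)
print(windows.shape)  # torch.Size([1, 1, 2, 4, 5])

# Tensor.unfold(dimension, size, step) builds the same sliding-window view
# without computing strides by hand.
unfolded = t.unfold(-1, kernel_size, stride)
print(torch.equal(windows, unfolded))  # True
```

Deriving the strides from t.stride() rather than hardcoding them should keep the view consistent if the number of channels or the sequence length changes; unfold does the same bookkeeping internally and may be the more readable option.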