How can I do this kind of reshape with a vectorized method in PyTorch?

For example, I have a 4x4x4 torch tensor:
x = torch.Tensor([
[[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]],
[[2, 2, 2, 2],
[2, 2, 2, 2],
[2, 2, 2, 2],
[2, 2, 2, 2]],
[[3, 3, 3, 3],
[3, 3, 3, 3],
[3, 3, 3, 3],
[3, 3, 3, 3]],
[[4, 4, 4, 4],
[4, 4, 4, 4],
[4, 4, 4, 4],
[4, 4, 4, 4]]
])
I want to convert it to a 1x8x8 tensor:
([[[1, 2, 1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4, 3, 4],
[1, 2, 1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4, 3, 4],
[1, 2, 1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4, 3, 4],
[1, 2, 1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4, 3, 4]]])

How can I do this kind of reshape with a vectorized method, without a for loop, in PyTorch?

This should work:

y = x.view(2, 2, 4, 4).permute(3, 0, 2, 1).reshape(1, 8, 8)
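For reference, here is a self-contained version of that one-liner, building the same constant-channel input with arange/repeat_interleave instead of spelling the values out:

```python
import torch

# Channel k of x is filled with the constant k + 1, matching the example above.
x = torch.arange(1., 5.).repeat_interleave(16).view(4, 4, 4)

# Arrange the 4 channels as a 2x2 channel grid, then interleave them spatially.
y = x.view(2, 2, 4, 4).permute(3, 0, 2, 1).reshape(1, 8, 8)

print(y.shape)           # torch.Size([1, 8, 8])
print(y[0, 0].tolist())  # [1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
print(y[0, 1].tolist())  # [3.0, 4.0, 3.0, 4.0, 3.0, 4.0, 3.0, 4.0]
```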

Thank you so much.
And there is a further question:
I have an 8x4x4 torch tensor:
x = torch.Tensor([
[[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]],
[[2, 2, 2, 2],
[2, 2, 2, 2],
[2, 2, 2, 2],
[2, 2, 2, 2]],
[[3, 3, 3, 3],
[3, 3, 3, 3],
[3, 3, 3, 3],
[3, 3, 3, 3]],
[[4, 4, 4, 4],
[4, 4, 4, 4],
[4, 4, 4, 4],
[4, 4, 4, 4]],
[[5, 5, 5, 5],
[5, 5, 5, 5],
[5, 5, 5, 5],
[5, 5, 5, 5]],
[[6, 6, 6, 6],
[6, 6, 6, 6],
[6, 6, 6, 6],
[6, 6, 6, 6]],
[[7, 7, 7, 7],
[7, 7, 7, 7],
[7, 7, 7, 7],
[7, 7, 7, 7]],
[[8, 8, 8, 8],
[8, 8, 8, 8],
[8, 8, 8, 8],
[8, 8, 8, 8]]
])
Using your answer, it is converted to a 2x8x8 tensor:
tensor([[[1., 2., 1., 2., 1., 2., 1., 2.],
[3., 4., 3., 4., 3., 4., 3., 4.],
[5., 6., 5., 6., 5., 6., 5., 6.],
[7., 8., 7., 8., 7., 8., 7., 8.],
[1., 2., 1., 2., 1., 2., 1., 2.],
[3., 4., 3., 4., 3., 4., 3., 4.],
[5., 6., 5., 6., 5., 6., 5., 6.],
[7., 8., 7., 8., 7., 8., 7., 8.]],

    [[1., 2., 1., 2., 1., 2., 1., 2.],
     [3., 4., 3., 4., 3., 4., 3., 4.],
     [5., 6., 5., 6., 5., 6., 5., 6.],
     [7., 8., 7., 8., 7., 8., 7., 8.],
     [1., 2., 1., 2., 1., 2., 1., 2.],
     [3., 4., 3., 4., 3., 4., 3., 4.],
     [5., 6., 5., 6., 5., 6., 5., 6.],
     [7., 8., 7., 8., 7., 8., 7., 8.]]])

but what I want is a 2x8x8 tensor like this:
tensor([[[1., 2., 1., 2., 1., 2., 1., 2.],
[3., 4., 3., 4., 3., 4., 3., 4.],
[1., 2., 1., 2., 1., 2., 1., 2.],
[3., 4., 3., 4., 3., 4., 3., 4.],
[1., 2., 1., 2., 1., 2., 1., 2.],
[3., 4., 3., 4., 3., 4., 3., 4.],
[1., 2., 1., 2., 1., 2., 1., 2.],
[3., 4., 3., 4., 3., 4., 3., 4.]],

    [[5., 6., 5., 6., 5., 6., 5., 6.],
     [7., 8., 7., 8., 7., 8., 7., 8.],
     [5., 6., 5., 6., 5., 6., 5., 6.],
     [7., 8., 7., 8., 7., 8., 7., 8.],
     [5., 6., 5., 6., 5., 6., 5., 6.],
     [7., 8., 7., 8., 7., 8., 7., 8.],
     [5., 6., 5., 6., 5., 6., 5., 6.],
     [7., 8., 7., 8., 7., 8., 7., 8.]]])

Could you help me? Thanks a lot!
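For what it's worth, one way to get that output is to split the 8 channels into 2 groups of 4 and apply the same view/permute trick per group. A sketch, assuming the grouping "first 4 channels → first output plane":

```python
import torch
import torch.nn.functional as F

# Channel k of x is filled with the constant k + 1, as in the example above.
x = torch.arange(1., 9.).repeat_interleave(16).view(8, 4, 4)

# Split the 8 channels into 2 groups, each a 2x2 channel grid,
# then interleave each group's channels spatially.
y = x.view(2, 2, 2, 4, 4).permute(0, 3, 1, 4, 2).reshape(2, 8, 8)

# This per-group interleave is exactly what pixel_shuffle computes:
assert torch.equal(y, F.pixel_shuffle(x.unsqueeze(0), 2).squeeze(0))
```

Note that `F.pixel_shuffle` with an upscale factor of 2 would also solve the original 4x4x4 question in a single call.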

Just curious about this…
Can’t we use only tensor.reshape, avoiding .view and .permute? What difference does it make?

tensor.reshape copies the data under the hood if needed. .view in that situation would raise an error explaining that your data is not contiguous in memory, so its strides and shape cannot be changed without elements overlapping. So reshape is not a replacement for permute; it is a replacement for .contiguous().view().
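A quick illustration of the difference on a non-contiguous tensor:

```python
import torch

t = torch.arange(6).view(2, 3)
p = t.permute(1, 0)   # a non-contiguous view (transposed strides)

try:
    p.view(6)         # view cannot re-stride a non-contiguous tensor
except RuntimeError as err:
    print("view failed:", err)

q = p.reshape(6)      # reshape copies when needed; same as p.contiguous().view(6)
print(q.tolist())     # [0, 3, 1, 4, 2, 5]
```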
