Understanding reshape

For a tensor:

import torch

x = torch.tensor([
    [
        [[0.4495, 0.2356],
         [0.4069, 0.2361],
         [0.4224, 0.2362]],

        [[0.4357, 0.6762],
         [0.4370, 0.6779],
         [0.4406, 0.6663]]
    ],
    [
        [[0.5796, 0.4047],
         [0.5655, 0.4080],
         [0.5431, 0.4035]],

        [[0.5338, 0.6255],
         [0.5335, 0.6266],
         [0.5204, 0.6396]]
    ]
])  # shape: (2, 2, 3, 2)

First, I would like to split it into 2 (x.shape[0]) tensors and then concatenate them. I don't actually have to split it as long as I get the correct output, but splitting and then concatenating makes a lot more sense to me visually.

For example:

# the splits always have the same shape (here they are x[0] and x[1])
split1 = torch.tensor([
    [[0.4495, 0.2356],
    [0.4069, 0.2361],
    [0.4224, 0.2362]],

    [[0.4357, 0.6762],
    [0.4370, 0.6779],
    [0.4406, 0.6663]]
])
split2 = torch.tensor([
    [[0.5796, 0.4047],
    [0.5655, 0.4080],
    [0.5431, 0.4035]],

    [[0.5338, 0.6255],
    [0.5335, 0.6266],
    [0.5204, 0.6396]]
])

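# join the two (3, 2) blocks of each split side by side -> (3, 4)
split1 = torch.cat((split1[0], split1[1]), dim=1)
split2 = torch.cat((split2[0], split2[1]), dim=1)
# put the per-split results back together -> (2, 3, 4)
what_i_want = torch.cat((split1, split2), dim=0).reshape(x.shape[0], split1.shape[0], split1.shape[1])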
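The split-then-concatenate idea above can be written so that it works for any x.shape[0] instead of naming split1, split2 by hand. This is a minimal sketch, assuming x has shape (batch, blocks, rows, cols) as in the question; the helper name `split_and_concat` is hypothetical, and an arange tensor stands in for the real values so the element order is easy to read.

```python
import torch


def split_and_concat(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: join each item's blocks side by side, for any batch size."""
    pieces = []
    for item in torch.unbind(x, dim=0):  # item: (blocks, rows, cols)
        # concatenate the blocks along the column axis -> (rows, blocks * cols)
        pieces.append(torch.cat(torch.unbind(item, dim=0), dim=1))
    # stack the per-item results -> (batch, rows, blocks * cols)
    return torch.stack(pieces, dim=0)


# stand-in with the question's shape (2, 2, 3, 2); values equal their flat index
x = torch.arange(24, dtype=torch.float32).reshape(2, 2, 3, 2)
out = split_and_concat(x)
print(out.shape)        # torch.Size([2, 3, 4])
print(out[0].tolist())  # [[0.0, 1.0, 6.0, 7.0], [2.0, 3.0, 8.0, 9.0], [4.0, 5.0, 10.0, 11.0]]
```

Using torch.stack for the last step adds the leading batch dimension directly, so the final cat(...).reshape(...) is not needed.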


For the result above, I thought reshaping directly with x.reshape([2, 3, 4]) would work. It produced the correct shape, but the elements were in the wrong order.
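One way to see why x.reshape([2, 3, 4]) scrambles the numbers is to use a tensor whose values equal their position in memory. In this sketch torch.arange stands in for x (same shape, different values). reshape only regroups the flat sequence, so each new row of 4 takes 4 consecutive values, while the desired row takes 2 values from each block.

```python
import torch

# stand-in for x with the same shape; each value equals its flat (memory) index
x = torch.arange(24).reshape(2, 2, 3, 2)

# reshape reads 0, 1, 2, ... in order and refills the new shape with them
naive = x.reshape(2, 3, 4)
print(naive[0].tolist())   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

# the desired layout puts row r of block 0 next to row r of block 1
wanted = torch.cat((x[:, 0], x[:, 1]), dim=2)
print(wanted[0].tolist())  # [[0, 1, 6, 7], [2, 3, 8, 9], [4, 5, 10, 11]]
```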

In general, I am:

  1. not sure how to split the tensor into x.shape[0] tensors;
  2. confused about how reshape works. Most of the time I get the shape right, but the order of the elements is always wrong.
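For the first point, the pieces do not need to be written out by hand. PyTorch already has functions that return x.shape[0] slices along dim 0: torch.unbind drops the split dimension, while Tensor.chunk and Tensor.split keep it with size 1. A short sketch, again using an arange stand-in with the question's shape:

```python
import torch

x = torch.arange(24).reshape(2, 2, 3, 2)  # stand-in with the same shape as x

# unbind: a tuple of x.shape[0] tensors, each of shape (2, 3, 2); same as x[0], x[1]
split1, split2 = torch.unbind(x, dim=0)

# chunk / split: same data, but each piece keeps a leading dimension of size 1
chunks = x.chunk(x.shape[0], dim=0)  # each piece: (1, 2, 3, 2)
pieces = x.split(1, dim=0)           # equivalent here

print(split1.shape, chunks[0].shape)  # torch.Size([2, 3, 2]) torch.Size([1, 2, 3, 2])
```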

Thank you

reshape() gives the same result as contiguous().view() (it returns a view when possible and copies only otherwise). These commands:
1) copy the data when needed, synchronizing the physical layout (memory read sequentially as indices increase) with the logical one (which can be changed by, for example, permute());
2) change the strides, i.e. the points where dimensions are split, while keeping the contiguous nested-block layout.
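The stride behaviour described in these two points can be inspected directly. In this minimal sketch (an arange stand-in is assumed), permute() only changes the strides, so the result is no longer contiguous and view() refuses to merge its last two dimensions, while reshape() falls back to a copy that matches contiguous().view().

```python
import torch

x = torch.arange(24).reshape(2, 2, 3, 2)
print(x.stride())         # (12, 6, 2, 1): row-major, nested contiguous blocks

p = x.permute(0, 2, 1, 3)  # logical order changes, memory does not
print(p.stride())         # (12, 2, 6, 1)
print(p.is_contiguous())  # False

# view() cannot merge p's last two dims (sizes 2, 2) with strides (6, 1)
try:
    p.view(2, 3, 4)
    view_ok = True
except RuntimeError:
    view_ok = False
print(view_ok)            # False

# reshape() has to copy here, giving the same result as contiguous().view()
r = p.reshape(2, 3, 4)
print(torch.equal(r, p.contiguous().view(2, 3, 4)))  # True
```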

As you can see, the logical order is not affected, i.e. no reordering is done. In contrast, cat() with dim > 0 interleaves the data, so its result is different. A better alternative to cat() is permute().reshape().
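Applied to the question, here is a sketch of the suggested permute().reshape() approach, checked against a concatenation-based reference. The permutation (0, 2, 1, 3) is one reading of the desired layout, not taken verbatim from the answer: it moves the row axis ahead of the block axis so that row r of each block becomes adjacent before reshape merges the last two axes. A seeded random tensor stands in for x.

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 2, 3, 2)  # random stand-in with the question's shape

# move the row axis (dim 2) in front of the block axis (dim 1), then merge the last two
out = x.permute(0, 2, 1, 3).reshape(x.shape[0], x.shape[2], -1)

# reference: concatenate the blocks side by side along the column axis
ref = torch.cat(x.unbind(dim=1), dim=2)

print(out.shape)              # torch.Size([2, 3, 4])
print(torch.equal(out, ref))  # True
```

Because permute() makes the tensor non-contiguous, reshape() copies once here, which is the reordering step that a plain reshape on x never performs.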