For a tensor:
import torch

x = torch.tensor([
    [
        [[0.4495, 0.2356],
         [0.4069, 0.2361],
         [0.4224, 0.2362]],
        [[0.4357, 0.6762],
         [0.4370, 0.6779],
         [0.4406, 0.6663]]
    ],
    [
        [[0.5796, 0.4047],
         [0.5655, 0.4080],
         [0.5431, 0.4035]],
        [[0.5338, 0.6255],
         [0.5335, 0.6266],
         [0.5204, 0.6396]]
    ]
])  # shape: (2, 2, 3, 2)
First, I would like to split it into 2 (i.e. x.shape[0]) tensors and then concatenate them. I don't actually have to split it, as long as I get the correct output, but splitting it and then concatenating the pieces back together makes a lot more sense to me visually.
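For reference, splitting along the first dimension does not require writing the pieces out by hand. A minimal sketch, assuming PyTorch's built-in torch.unbind and torch.split are acceptable: unbind returns x.shape[0] tensors with the first dimension removed, while split with a chunk size of 1 keeps that dimension with size 1. Unpacking into split1, split2 assumes x.shape[0] == 2.

```python
import torch

# The tensor from the question, written compactly; shape (2, 2, 3, 2).
x = torch.tensor([
    [[[0.4495, 0.2356], [0.4069, 0.2361], [0.4224, 0.2362]],
     [[0.4357, 0.6762], [0.4370, 0.6779], [0.4406, 0.6663]]],
    [[[0.5796, 0.4047], [0.5655, 0.4080], [0.5431, 0.4035]],
     [[0.5338, 0.6255], [0.5335, 0.6266], [0.5204, 0.6396]]],
])

# torch.unbind drops dim 0 and returns a tuple of x.shape[0] tensors,
# each of shape (2, 3, 2). They are views of x, not copies.
splits = torch.unbind(x, dim=0)
split1, split2 = splits  # assumes exactly 2 splits

# torch.split with a chunk size of 1 keeps dim 0 instead,
# so each piece has shape (1, 2, 3, 2).
kept = torch.split(x, 1, dim=0)

print(len(splits), split1.shape, kept[0].shape)
```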
For example:
# the shape of the splits is always the same: (2, 3, 2)
split1 = torch.tensor([
    [[0.4495, 0.2356],
     [0.4069, 0.2361],
     [0.4224, 0.2362]],
    [[0.4357, 0.6762],
     [0.4370, 0.6779],
     [0.4406, 0.6663]]
])
split2 = torch.tensor([
    [[0.5796, 0.4047],
     [0.5655, 0.4080],
     [0.5431, 0.4035]],
    [[0.5338, 0.6255],
     [0.5335, 0.6266],
     [0.5204, 0.6396]]
])

# place the two (3, 2) matrices of each split side by side -> shape (3, 4)
split1 = torch.cat((split1[0], split1[1]), dim=1)
split2 = torch.cat((split2[0], split2[1]), dim=1)

# stack the rows -> (6, 4), then regroup into (2, 3, 4)
what_i_want = torch.cat((split1, split2), dim=0).reshape(x.shape[0], split1.shape[0], split1.shape[1])
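The same steps can be written without hard-coding the splits. This sketch (one possible approach, not the only one) generalizes the split-and-concatenate loop to any x.shape[0] using torch.unbind and torch.stack, and compares it with a permute-then-reshape version: swapping dims 1 and 2 makes the two values that should sit side by side adjacent in memory order, so the last two dims can then be merged. The names n, a, r, c are just labels introduced here for the four dimensions.

```python
import torch

# The tensor from the question; shape (2, 2, 3, 2).
x = torch.tensor([
    [[[0.4495, 0.2356], [0.4069, 0.2361], [0.4224, 0.2362]],
     [[0.4357, 0.6762], [0.4370, 0.6779], [0.4406, 0.6663]]],
    [[[0.5796, 0.4047], [0.5655, 0.4080], [0.5431, 0.4035]],
     [[0.5338, 0.6255], [0.5335, 0.6266], [0.5204, 0.6396]]],
])

# Split-and-concatenate for any x.shape[0]: each split s has shape (2, 3, 2);
# tuple(s) gives (s[0], s[1]), which cat places side by side as (3, 4);
# stack then collects the per-split results into (N, 3, 4).
what_i_want = torch.stack([torch.cat(tuple(s), dim=1) for s in torch.unbind(x, dim=0)])

# Same result without splitting:
# (N, 2, 3, 2) -> permute -> (N, 3, 2, 2) -> reshape -> (N, 3, 4).
n, a, r, c = x.shape
one_liner = x.permute(0, 2, 1, 3).reshape(n, r, a * c)

print(torch.equal(what_i_want, one_liner))  # True
print(one_liner[0, 0])  # first row holds 0.4495, 0.2356, 0.4357, 0.6762
```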
To get the above result, I thought directly reshaping with x.reshape([2, 3, 4]) would work. It produced the correct shape, but the values ended up in the wrong positions.
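A small check that illustrates why the direct reshape gives the wrong values: reshape never moves elements. It reads x in row-major order (last dimension varying fastest) and refills the new shape with that same flat sequence. This sketch reuses the values of x from above.

```python
import torch

# The tensor from the question; shape (2, 2, 3, 2).
x = torch.tensor([
    [[[0.4495, 0.2356], [0.4069, 0.2361], [0.4224, 0.2362]],
     [[0.4357, 0.6762], [0.4370, 0.6779], [0.4406, 0.6663]]],
    [[[0.5796, 0.4047], [0.5655, 0.4080], [0.5431, 0.4035]],
     [[0.5338, 0.6255], [0.5335, 0.6266], [0.5204, 0.6396]]],
])

# reshape only regroups the elements: the flat, row-major sequence is unchanged.
wrong = x.reshape(2, 3, 4)
print(torch.equal(wrong.flatten(), x.flatten()))  # True

# So the first row of the result is just the first 4 values of x, i.e. rows 0
# and 1 of x[0, 0], rather than x[0, 0, 0] followed by x[0, 1, 0].
print(wrong[0, 0])  # values 0.4495, 0.2356, 0.4069, 0.2361
```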
In general, I am:
- not sure how to split the tensor into x.shape[0] tensors.
- confused about how reshape works. Most of the time I am able to get the dimensions right, but the order of the numbers is always wrong.
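On the general question of how reshape orders elements, a toy tensor whose values are their own flat positions makes the behavior visible. The tensor t here is a hypothetical stand-in with the same shape as x. The rule of thumb it illustrates: reshape keeps the flat order, so if elements need to be rearranged, permute (or transpose) the axes first and reshape afterwards.

```python
import torch

# Hypothetical toy tensor with the same shape as x; each value equals its
# position in row-major order, so it is easy to see where elements end up.
t = torch.arange(24).reshape(2, 2, 3, 2)

# reshape only regroups the flat sequence; the order never changes.
plain = t.reshape(2, 3, 4)
print(plain[0].tolist())  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

# permute changes which axis is walked first; reshape then merges the last two
# axes (old dims 1 and 3) in the new order. permute returns a non-contiguous
# view, so reshape copies here, whereas .view(2, 3, 4) would raise an error.
reordered = t.permute(0, 2, 1, 3).reshape(2, 3, 4)
print(reordered[0].tolist())  # [[0, 1, 6, 7], [2, 3, 8, 9], [4, 5, 10, 11]]
```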
Thank you