I’m having trouble performing the following tensor transformation with PyTorch ops rather than for loops. I want to transform x into target_x. I’ve tried several combinations of torch.cat and torch.transpose, but I haven’t been able to do it. Am I missing some operation?
import torch

x = torch.tensor([
    [
        [
            [ 1,  2],
            [11, 12]
        ],
        [
            [ 3,  4],
            [13, 14]
        ]
    ],
    [
        [
            [21, 22],
            [31, 32]
        ],
        [
            [23, 24],
            [33, 34]
        ]
    ]
])
target_x = torch.tensor([
    [
        [
            [ 1,  2],
            [ 3,  4]
        ],
        [
            [11, 12],
            [13, 14]
        ]
    ],
    [
        [
            [21, 22],
            [23, 24]
        ],
        [
            [31, 32],
            [33, 34]
        ]
    ]
])
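
Reading the two tensors side by side, the mapping appears to be target_x[b, i, j, :] = x[b, j, i, :], i.e. the two middle dimensions are exchanged while the first and last stay in place. The following is a minimal sketch under that assumption, not a definitive answer: it builds the for-loop version for reference and compares it with a single torch.transpose call on dims 1 and 2. The names looped, swapped, matches_loop and matches_target are only illustrative.

```python
import torch

x = torch.tensor([
    [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
    [[[21, 22], [31, 32]], [[23, 24], [33, 34]]],
])
target_x = torch.tensor([
    [[[1, 2], [3, 4]], [[11, 12], [13, 14]]],
    [[[21, 22], [23, 24]], [[31, 32], [33, 34]]],
])

# Reference for-loop version of the assumed mapping:
# looped[b, i, j, :] = x[b, j, i, :]
B, D1, D2, D3 = x.shape
looped = torch.empty(B, D2, D1, D3, dtype=x.dtype)
for b in range(B):
    for i in range(D2):
        for j in range(D1):
            looped[b, i, j] = x[b, j, i]

# Loop-free version: swap dimensions 1 and 2.
swapped = x.transpose(1, 2)

matches_loop = torch.equal(looped, target_x)
matches_target = torch.equal(swapped, target_x)
print(matches_loop, matches_target)
```

Note that torch.transpose returns a view that shares storage with x, so calling .contiguous() on the result may be needed before operations such as view that require a contiguous layout. x.permute(0, 2, 1, 3) expresses the same swap.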