Given a 2D tensor A of arbitrary shape, e.g. (8, 2):
A = torch.tensor([[0., 0.],[0., 1.],[0., 2.],[0., 3.],[0., 4.],[0., 5.],[0., 6.],[0., 7.]])
and an index tensor B of the same length as the first dimension of A:
B = torch.tensor([0, 1, 0, 1, 2, 3, 2, 3], dtype=torch.int32)
I would like to rearrange A according to the indices in B, so that rows of A sharing the same index in B are grouped together.
In this case, the result should have shape (4, 2, 2):
[
[[0., 0.], [0., 2.]],
[[0., 1.], [0., 3.]],
[[0., 4.], [0., 6.]],
[[0., 5.], [0., 7.]]
]
Is there any way to do this without a for loop over B?
Thanks in advance
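
One possible loop-free approach, offered as a minimal sketch rather than a definitive answer: stably sort the indices in B, reorder the rows of A with the resulting permutation, then reshape. It assumes that every index in B occurs the same number of times (twice in the example), since otherwise the groups would not fit a regular (groups, rows per group, columns) tensor. The helper name `group_rows` is hypothetical and only used for illustration.

```python
import torch

A = torch.tensor([[0., 0.], [0., 1.], [0., 2.], [0., 3.],
                  [0., 4.], [0., 5.], [0., 6.], [0., 7.]])
B = torch.tensor([0, 1, 0, 1, 2, 3, 2, 3], dtype=torch.int32)


def group_rows(A, B):
    """Group the rows of A by their index in B (hypothetical helper).

    Assumption: every distinct index in B occurs equally often.
    """
    # A stable sort keeps rows with the same index in their original order.
    order = torch.sort(B, stable=True).indices
    num_groups = torch.unique(B).numel()
    # Reorder the rows, then split them into equally sized groups.
    return A[order].reshape(num_groups, -1, *A.shape[1:])


result = group_rows(A, B)
print(result.shape)  # torch.Size([4, 2, 2])
```

For the example above, the sort permutation is [0, 2, 1, 3, 4, 6, 5, 7], so rows 0 and 2 form the first group, rows 1 and 3 the second, and so on. If the group sizes in B differ, the reshape fails or mixes groups, and a padded or list-based result would be needed instead.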