Rearrange Tensor by Index

Given a 2D tensor A of arbitrary shape (e.g. (8, 2)):

A = torch.tensor([[0., 0.],[0., 1.],[0., 2.],[0., 3.],[0., 4.],[0., 5.],[0., 6.],[0., 7.]])

and an index tensor B of the same length as the first dimension of A:

B = torch.IntTensor([0, 1, 0, 1, 2, 3, 2, 3])

I would like to rearrange A according to the indices in B, so that rows with the same index end up grouped together.
In this case, I need A to have shape (4, 2, 2):

[[0., 0.], [0., 2.]],
[[0., 1.], [0., 3.]],
[[0., 4.], [0., 6.]],
[[0., 5.], [0., 7.]]

Is there any way to do this without a for loop over B?

Thanks in advance

I’m not sure how you would get (4, 2) from 8 (maybe just view(…), if you know in advance that there are two of each?), but the indexing part, which gives an 8x2 tensor, could be idx = torch.argsort(B); Asorted = A[idx].
Note that argsort might not be stable (I don’t think that is guaranteed, in particular on CUDA). You could also do idx = idx.view(4, 2).sort(1).values or so to sort within each group again and then index into A.

Best regards


P.S.: Don’t use torch.IntTensor; use torch.tensor there, too, with dtype=torch.int32 if needed.
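For reference, the sort-then-view idea above could look roughly like the sketch below. It assumes a PyTorch version where torch.sort accepts stable=True (used here instead of argsort so that equal indices keep their original order), and it hard-codes the (4, 2, 2) shape from the example, which only works if every index occurs exactly twice.

```python
import torch

A = torch.tensor([[0., 0.], [0., 1.], [0., 2.], [0., 3.],
                  [0., 4.], [0., 5.], [0., 6.], [0., 7.]])
# torch.tensor with an explicit dtype instead of torch.IntTensor
B = torch.tensor([0, 1, 0, 1, 2, 3, 2, 3], dtype=torch.int32)

# A stable sort keeps rows with the same index in their original order,
# so no second sort within each group is needed.
idx = torch.sort(B, stable=True).indices

# A[idx] has shape (8, 2) with rows grouped by index; the view splits it
# into 4 groups of 2 rows each (assumes two rows per index).
result = A[idx].view(4, 2, 2)
print(result)
```
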

Thank you,
since I know the indices in B go from 0 to n, I just solved it like:

idx = torch.argsort(B)
dim_0 = torch.max(B) + 1
dim_1 = torch.div(idx.shape[0], dim_0, rounding_mode='trunc')
A = A[idx].view(dim_0, dim_1, -1)

If you don’t know the number in advance, I would probably use dim_0 = torch.max(B).item() + 1 and idx.shape[0] // dim_0 instead, to reduce mixed GPU/CPU ops.

If you do know it in advance, you can avoid the sync point (explicit in my .item(), implicit in your use of dim_0 and dim_1 as arguments to .view()).
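Put together, the .item() variant might look something like the following sketch. The helper name group_rows_by_index is made up for illustration, the stable sort is my own choice rather than part of the solution above, and the code assumes the indices run from 0 to n with every index occurring equally often.

```python
import torch

def group_rows_by_index(A, B):
    """Hypothetical helper: group the rows of A by the indices in B."""
    # .item() is the one explicit sync point; afterwards the sizes are
    # plain Python ints, so .view() needs no further conversion.
    n_groups = int(B.max().item()) + 1
    group_size = B.shape[0] // n_groups
    # Stable sort so rows sharing an index keep their original order.
    idx = torch.sort(B, stable=True).indices
    return A[idx].view(n_groups, group_size, -1)

A = torch.arange(16.).view(8, 2)
B = torch.tensor([0, 1, 0, 1, 2, 3, 2, 3], dtype=torch.int32)
out = group_rows_by_index(A, B)
print(out.shape)
```

If the number of groups is known in advance, n_groups can be passed in as a plain int instead, which avoids the sync altogether.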

Best regards



Ok, thanks for the tip!