Indexing a higher-dimensional tensor with a lower-dimensional tensor, i.e. (B, S, D) and (B, S) shaped tensors

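Suppose I have a tensor A of shape (B, S, D) and an index tensor B of shape (B, S) (so B names both the index tensor and the batch dimension). I want to use the index tensor to select D-dimensional vectors from A: for each batch b and position s, the output should hold A[b, B[b, s]], giving an output tensor of shape (B, S, D).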

For example, suppose

A = [[[ 1,  2,  3],
      [ 4,  5,  6]],
     [[ 7,  8,  9],
      [10, 11, 12]]]

a (2, 2, 3) tensor, and

B = [[0, 0],
     [1, 0]]

Then the result would be

[[[ 1,  2,  3],
  [ 1,  2,  3]],
 [[10, 11, 12],
  [ 7,  8,  9]]]
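A minimal sketch of the example above in PyTorch, spelling out the rule out[b, s, :] = A[b, idx[b, s], :] with an explicit double loop. The index tensor is called `idx` here (an assumed name, not from the post) so it does not clash with the batch size B.

```python
import torch

# Tensors from the example: A has shape (B, S, D) = (2, 2, 3),
# idx has shape (B, S) = (2, 2).
A = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])
idx = torch.tensor([[0, 0],
                    [1, 0]])

# Element-wise definition: idx[b, s] picks a row along dim 1,
# always within the same batch entry b.
out = torch.empty_like(A)
for b in range(A.size(0)):
    for s in range(idx.size(1)):
        out[b, s] = A[b, idx[b, s].item()]

# Expected result as written in the post.
expected = torch.tensor([[[1, 2, 3],
                          [1, 2, 3]],
                         [[10, 11, 12],
                          [7, 8, 9]]])
assert torch.equal(out, expected)
```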

This can be accomplished using a for loop like so:

import torch

def function(A, indices):
    C = torch.zeros_like(A)
    for i in range(A.size(0)):
        # A[i, indices[i]] picks the rows of A[i] listed in indices[i]
        C[i] = A[i, indices[i]]
    return C

Is there a way to do this faster and without a for loop?

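Use advanced indexing with a batch index of shape (B, 1), which broadcasts against the (B, S) index tensor:

A[torch.arange(B.size(0)).unsqueeze(1), B]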
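A sketch comparing the loop from the question with the one-line indexing answer, assuming the index tensor is named `idx` and using hypothetical helper names (`gather_rows_loop`, `gather_rows_indexing`, `gather_rows_gather`). The `torch.gather` variant is a swapped-in alternative that does not appear in the thread; it is included only to show that it selects the same rows.

```python
import torch

def gather_rows_loop(A, idx):
    # Reference: the for-loop version from the question.
    C = torch.zeros_like(A)
    for i in range(A.size(0)):
        C[i] = A[i, idx[i]]
    return C

def gather_rows_indexing(A, idx):
    # Batch index has shape (B, 1); it broadcasts with idx (B, S) to (B, S),
    # so entry [b, s] selects A[b, idx[b, s], :] and the result is (B, S, D).
    batch = torch.arange(A.size(0), device=A.device).unsqueeze(1)
    return A[batch, idx]

def gather_rows_gather(A, idx):
    # torch.gather needs an index with as many dims as A, so expand
    # idx from (B, S) to (B, S, D) and gather along dim 1.
    index = idx.unsqueeze(-1).expand(-1, -1, A.size(-1))
    return torch.gather(A, 1, index)

# The example from the post.
A = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                  [[7, 8, 9], [10, 11, 12]]])
idx = torch.tensor([[0, 0], [1, 0]])
ref = gather_rows_loop(A, idx)
assert torch.equal(gather_rows_indexing(A, idx), ref)
assert torch.equal(gather_rows_gather(A, idx), ref)

# A larger random case.
torch.manual_seed(0)
A_rand = torch.randn(4, 5, 3)
idx_rand = torch.randint(0, 5, (4, 5))
ref_rand = gather_rows_loop(A_rand, idx_rand)
assert torch.equal(gather_rows_indexing(A_rand, idx_rand), ref_rand)
assert torch.equal(gather_rows_gather(A_rand, idx_rand), ref_rand)
```

Both vectorized versions avoid the Python-level loop over the batch; the indexing version is the one-liner above wrapped in a function.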
