Indexing 3D input matrix using 2D index

Hey all!

I want to index an input 3D matrix of size, say, (3, 4, 2) using a 2D index matrix of size (3, 4). Each entry of the index matrix is either 0 or 1 and picks one of the two elements along the last dimension of the input.

For example: 
Input matrix (A) = torch.tensor([[[1,2],[3,4],[0,4],[9,2]], [[5,6],[7,8],[8,8],[7,6]], [[9, 10],[11,12],[1,9],[2,2]]])
index matrix (B) = torch.tensor([[1,1,0,0], [0,0,0,1], [1,0,0,0]])
Want: torch.tensor([[2,4,0,9], [5,7,8,6], [10,11,1,2]])

I am currently doing this inefficiently with a Python loop:
torch.cat([A[i][torch.arange(B.size(1)).unsqueeze(0), B[i]] for i in range(B.size(0))], dim=0)

It would be nice if someone could suggest a neater way to get the above.

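To index a 3D tensor along its last dimension using a 2D index tensor, you can use the torch.gather function. gather needs the index to have the same number of dimensions as the input, so first unsqueeze the index to shape (3, 4, 1), gather along dim 2, and then squeeze the extra dimension away. Here’s how to do it: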

import torch

# Create the input 3D tensor (A)
input_matrix = torch.tensor([[[1, 2], [3, 4], [0, 4], [9, 2]],
                             [[5, 6], [7, 8], [8, 8], [7, 6]],
                             [[9, 10], [11, 12], [1, 9], [2, 2]]])

# Create the index 2D tensor (B)
index_matrix = torch.tensor([[1, 1, 0, 0],
                             [0, 0, 0, 1],
                             [1, 0, 0, 0]])

# Index the input tensor using the index tensor
output_matrix = torch.gather(input_matrix, 2, index_matrix.unsqueeze(2))

# Remove the extra dimension
output_matrix = output_matrix.squeeze(2)

print(output_matrix)  # values: [[2, 4, 0, 9], [5, 7, 8, 6], [10, 11, 1, 2]]
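
If it helps to see why this works: gather along dim 2 computes out[i, j, 0] = A[i, j, B[i, j]], which is exactly the lookup asked for. As a rough sketch, the same result should also be obtainable with plain advanced indexing by broadcasting row and column indices against B. The names rows, cols, via_gather and via_indexing below are only illustrative, and this assumes B holds int64 indices (the default for torch.tensor with Python ints):

```python
import torch

A = torch.tensor([[[1, 2], [3, 4], [0, 4], [9, 2]],
                  [[5, 6], [7, 8], [8, 8], [7, 6]],
                  [[9, 10], [11, 12], [1, 9], [2, 2]]])
B = torch.tensor([[1, 1, 0, 0],
                  [0, 0, 0, 1],
                  [1, 0, 0, 0]])

# gather along dim 2: out[i, j, 0] = A[i, j, B[i, j]]
via_gather = torch.gather(A, 2, B.unsqueeze(2)).squeeze(2)

# Advanced indexing: rows (3, 1) and cols (1, 4) broadcast with B (3, 4)
rows = torch.arange(A.size(0)).unsqueeze(1)
cols = torch.arange(A.size(1)).unsqueeze(0)
via_indexing = A[rows, cols, B]  # shape (3, 4)

# Both approaches should match the "Want" tensor from the question
expected = torch.tensor([[2, 4, 0, 9], [5, 7, 8, 6], [10, 11, 1, 2]])
assert torch.equal(via_gather, expected)
assert torch.equal(via_indexing, expected)
```

Either form avoids the Python loop. gather is probably the more readable choice when the index only varies along one dimension, as it does here.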