Hey all!
I want to index a 3D input tensor of size, say, (3, 4, 2) with a 2D index tensor of size (3, 4). Each entry of the index tensor is 0 or 1 and selects one of the two elements along the last dimension of the input, i.e. out[i, j] = A[i, j, B[i, j]].
For example:
Input tensor A, shape (3, 4, 2): A = torch.tensor([[[1, 2], [3, 4], [0, 4], [9, 2]], [[5, 6], [7, 8], [8, 8], [7, 6]], [[9, 10], [11, 12], [1, 9], [2, 2]]])
Index tensor B, shape (3, 4): B = torch.tensor([[1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]])
Desired output, shape (3, 4): torch.tensor([[2, 4, 0, 9], [5, 7, 8, 6], [10, 11, 1, 2]])
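To pin down the intended rule, here is a minimal sketch that computes the desired output element by element with plain Python loops over the example tensors. It is slow and only meant as an unambiguous reference; the name `out` is illustrative.

```python
import torch

# Example tensors from the post: A has shape (3, 4, 2), B has shape (3, 4)
A = torch.tensor([[[1, 2], [3, 4], [0, 4], [9, 2]],
                  [[5, 6], [7, 8], [8, 8], [7, 6]],
                  [[9, 10], [11, 12], [1, 9], [2, 2]]])
B = torch.tensor([[1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]])

# Reference rule: out[i, j] = A[i, j, B[i, j]]
out = torch.zeros(B.shape, dtype=A.dtype)
for i in range(B.size(0)):
    for j in range(B.size(1)):
        k = int(B[i, j])  # 0 or 1: which element of the last dimension to take
        out[i, j] = A[i, j, k]

print(out)
```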
Currently I use the following loop over the rows, which is inefficient:
torch.cat([A[i][torch.arange(B.size(1)).unsqueeze(0), B[i]] for i in range(B.size(0))], dim=0)
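The loop above works because of broadcasting inside each row: `A[i]` has shape (4, 2), and pairing the column indices of shape (1, 4) with `B[i]` of shape (4,) broadcasts to (1, 4), so entry j is `A[i][j, B[i][j]]`. A sketch of the same loop with the intermediate shapes written out (the names `cols` and `rows` are only for illustration):

```python
import torch

A = torch.tensor([[[1, 2], [3, 4], [0, 4], [9, 2]],
                  [[5, 6], [7, 8], [8, 8], [7, 6]],
                  [[9, 10], [11, 12], [1, 9], [2, 2]]])
B = torch.tensor([[1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]])

cols = torch.arange(B.size(1)).unsqueeze(0)  # shape (1, 4): column indices 0..3
rows = []
for i in range(B.size(0)):
    # A[i]: (4, 2); cols (1, 4) and B[i] (4,) broadcast to (1, 4)
    rows.append(A[i][cols, B[i]])            # one (1, 4) row per i
out = torch.cat(rows, dim=0)                 # stack the rows -> (3, 4)

print(out)
```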
Could someone suggest a neater, vectorized way to get the same result?
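One possible vectorized approach, sketched under the assumption that B holds int64 indices (which `torch.tensor` produces by default for Python ints). Option 1 uses `torch.gather` along the last dimension; option 2 extends the broadcasting trick from the loop to all three dimensions at once, so no Python loop is needed. The variable names are illustrative.

```python
import torch

A = torch.tensor([[[1, 2], [3, 4], [0, 4], [9, 2]],
                  [[5, 6], [7, 8], [8, 8], [7, 6]],
                  [[9, 10], [11, 12], [1, 9], [2, 2]]])
B = torch.tensor([[1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]])

# Option 1: torch.gather along dim 2.
# The index needs as many dims as A, so add a trailing axis to B:
# (3, 4) -> (3, 4, 1), gather, then drop that axis again.
out_gather = torch.gather(A, 2, B.unsqueeze(-1)).squeeze(-1)

# Option 2: advanced indexing with broadcast row/column index tensors.
rows = torch.arange(A.size(0)).unsqueeze(1)  # (3, 1)
cols = torch.arange(A.size(1)).unsqueeze(0)  # (1, 4)
out_index = A[rows, cols, B]                 # rows, cols, B broadcast to (3, 4)

print(out_gather)
print(out_index)
```

Both should match the desired output above. `gather` keeps everything in one call and generalizes to any number of choices along the last dimension; the indexing version makes the row/column pairing explicit.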