I have a 3-D tensor data of shape block × example × label. I have another 2-D tensor index of shape n × block.

I want to select a slice from data that matches index: for each row of index and each block, I want to pick the label given by that row's entry for the block, keeping all of the examples.
Since this is hard to explain in words, here is an example.
Suppose a row of index looks like [10, 20, 30, 40, 50]. For this row, I want to select the following slices from data: (0, :, 10), (1, :, 20), (2, :, 30), (3, :, 40), (4, :, 50).
My final output should then have shape n × block × example.
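To make the intended selection unambiguous, here is a slow reference version written as explicit loops. It is a sketch that assumes NumPy arrays (the question does not name a library); the function name select_by_index_loop and the array sizes are hypothetical, chosen only for illustration. It computes out[r, b, :] = data[b, :, index[r, b]].

```python
import numpy as np

def select_by_index_loop(data: np.ndarray, index: np.ndarray) -> np.ndarray:
    """Hypothetical reference helper: out[r, b, :] = data[b, :, index[r, b]]."""
    block, example, _label = data.shape
    n, index_block = index.shape
    assert index_block == block, "index must have one column per block"
    out = np.empty((n, block, example), dtype=data.dtype)
    for r in range(n):
        for b in range(block):
            # For row r, block b keeps every example at label index[r, b].
            out[r, b, :] = data[b, :, index[r, b]]
    return out

# Arbitrary sizes: block=5, example=3, label=60, and n=1 using the row from the question.
rng = np.random.default_rng(0)
data = rng.standard_normal((5, 3, 60))
index = np.array([[10, 20, 30, 40, 50]])
out = select_by_index_loop(data, index)
print(out.shape)  # (1, 5, 3)
```

For this row, out[0, 2] is data[2, :, 30], which is the (2, :, 30) slice listed above.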
Is there a clean and efficient way to do it?
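One possible vectorized approach, sketched here under the assumption that NumPy is acceptable, is a single advanced-indexing expression. The helper name select_by_index is hypothetical. Indexing with an arange over the block axis plus index on the label axis, with a slice in between, makes NumPy put the broadcast index shape (n, block) first and the sliced example axis last, which gives the desired n × block × example result.

```python
import numpy as np

def select_by_index(data: np.ndarray, index: np.ndarray) -> np.ndarray:
    """Hypothetical helper: out[r, b, e] = data[b, e, index[r, b]]."""
    block = data.shape[0]
    # Shape (1, block); broadcasts against index, which has shape (n, block).
    block_ids = np.arange(block)[None, :]
    # The two integer arrays are separated by a slice, so NumPy places the
    # broadcast index shape (n, block) first, followed by the sliced example
    # axis: the result has shape (n, block, example).
    return data[block_ids, :, index]

# Arbitrary sizes for illustration: block=5, example=3, label=60, n=4.
rng = np.random.default_rng(0)
data = rng.standard_normal((5, 3, 60))
index = rng.integers(0, 60, size=(4, 5))
out = select_by_index(data, index)
print(out.shape)  # (4, 5, 3)
```

PyTorch documents NumPy-style advanced indexing, so the same expression with torch.arange in place of np.arange should behave the same way on tensors, but that is an assumption worth checking against a loop version on small inputs.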