I have a 3-d tensor `data` of shape `block X example X label`. I have another 2-d tensor `index` of shape `n X block`.

I want to select a slice from `data` that matches `index`: for each row of `index` and each block `b`, I want to pick the `label` given by that row's `b`-th entry, taken across all examples.

I realize this is hard to explain, so here is an example.

Let's say a row in `index` looks like `[10, 20, 30, 40, 50]`. For this row, I want to select the following elements from `data`: `(0,:,10), (1,:,20), (2,:,30), (3,:,40), (4,:,50)`.
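To pin down the exact semantics, here is a slow reference version written with explicit loops. It is a sketch assuming NumPy arrays (the post does not say which tensor library is used); the function name `select_labels_loop` and the sizes below are hypothetical. The rule it implements is `out[i, b, :] = data[b, :, index[i, b]]`.

```python
import numpy as np

def select_labels_loop(data, index):
    """Hypothetical reference version: out[i, b, :] = data[b, :, index[i, b]]."""
    n, block = index.shape
    _, example, _ = data.shape
    out = np.empty((n, block, example), dtype=data.dtype)
    for i in range(n):
        for b in range(block):
            # Row i of index says which label to take for block b,
            # and that label is taken for every example.
            out[i, b] = data[b, :, index[i, b]]
    return out

# Made-up sizes: block=5, example=3, label=60, and the single row from the example above.
rng = np.random.default_rng(0)
data = rng.standard_normal((5, 3, 60))
index = np.array([[10, 20, 30, 40, 50]])
out = select_labels_loop(data, index)
print(out.shape)  # (1, 5, 3)
```

For the row `[10, 20, 30, 40, 50]`, `out[0, 2]` is exactly `data[2, :, 30]`, matching the list of elements above.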

My final output would then have shape `n X block X example`.

Is there a clean and efficient way to do it?
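One possible vectorized approach, sketched under the assumption that `data` and `index` are NumPy arrays (with `index` holding integers), is a single advanced-indexing expression: pair a block range of shape `(1, block)` with `index` of shape `(n, block)` and keep the example axis as a slice. The helper name `select_labels` is hypothetical.

```python
import numpy as np

def select_labels(data, index):
    """Hypothetical helper: returns out with out[i, b, :] == data[b, :, index[i, b]]."""
    block = data.shape[0]
    # Shape (1, block): broadcasts against index of shape (n, block).
    blocks = np.arange(block)[None, :]
    # The two advanced indices (axes 0 and 2) are separated by a slice, so NumPy
    # places their broadcast shape (n, block) first and the sliced example axis last.
    return data[blocks, :, index]

# Made-up sizes: block=5, example=3, label=60, n=4.
rng = np.random.default_rng(0)
data = rng.standard_normal((5, 3, 60))
index = rng.integers(0, 60, size=(4, 5))
out = select_labels(data, index)
print(out.shape)  # (4, 5, 3)
```

This avoids Python loops entirely. If the tensors are PyTorch tensors, PyTorch's indexing is intended to follow NumPy's advanced-indexing rules, so `data[torch.arange(block)[None, :], :, index]` should behave the same way, but that is an assumption worth checking against the loop version on a small input.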