Selecting a specific slice from a tensor

I have a 3-d tensor data of shape block X example X label.
I have another 2-d tensor index of shape n X block.

I want to select slices from data according to index: for each row of index, the entry at position b gives the label to pick from block b, across all examples.
This is hard to explain in words, so here is an example.
Let's say a row in index is [10, 20, 30, 40, 50].
For this row, I want to select the following elements from data: (0,:,10), (1,:,20), (2,:,30), (3,:,40), (4,:,50)

My final output will then have a shape of n X block X example
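
To make the intended result concrete, here is a naive loop-based sketch of the selection. The helper name select_naive and the sizes are made up for illustration; it is meant as a reference for correctness, not as an efficient solution.

```python
import torch

def select_naive(data, index):
    # data:  (block, example, label)
    # index: (n, block), each entry is a label position for that block
    n, block = index.shape
    example = data.shape[1]
    out = torch.empty(n, block, example, dtype=data.dtype)
    for r in range(n):
        for b in range(block):
            # pick label index[r, b] from block b, for every example
            out[r, b] = data[b, :, index[r, b]]
    return out

# Hypothetical sizes: 5 blocks, 3 examples, 60 labels, one index row
data = torch.randn(5, 3, 60)
index = torch.tensor([[10, 20, 30, 40, 50]])
out = select_naive(data, index)  # shape: n X block X example
```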

Is there a clean and efficient way to do it?

This might work:

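import torch

block, example, label = 2, 3, 4
N = 5

data = torch.randn(block, example, label)
# random label positions in [0, label), one per (row, block) pair
index = torch.empty(N, block, dtype=torch.long).random_(label)
# gather along the label dim for each row of index, then stack the results
res = torch.stack([torch.gather(data, 2, i[:, None, None].expand(-1, example, -1)) for i in index])
# stacked shape is (N, block, example, 1); drop the trailing singleton dim
res = res.view(N, block, example)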

Could you check this code with your data?
I'm not sure about the last .view operation, or whether it's possible to get rid of the list comprehension.
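
One possible way to drop the list comprehension is advanced indexing. The following is only a sketch with the same toy sizes as above, assuming index holds valid label positions; the names data_t and ref are introduced here for illustration. It swaps the example and label dims so that the block and label dims are adjacent, then indexes both at once.

```python
import torch

block, example, label = 2, 3, 4
N = 5

data = torch.randn(block, example, label)
index = torch.randint(label, (N, block))  # values in [0, label)

# (block, example, label) -> (block, label, example)
data_t = data.transpose(1, 2)
# arange(block) has shape (block,) and broadcasts against index (N, block),
# so res[n, b] = data_t[b, index[n, b]] = data[b, :, index[n, b]]
res = data_t[torch.arange(block), index]  # shape (N, block, example)

# Loop-based gather version, kept only as a cross-check
ref = torch.stack(
    [torch.gather(data, 2, i[:, None, None].expand(-1, example, -1)) for i in index]
).squeeze(-1)
```

Both versions only copy values, so the two results should match exactly; it would still be worth comparing them on your real data.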