Selecting a specific slice from a tensor

I have a 3-D tensor data of shape (block, example, label).
I have another 2-D tensor index of shape (n, block), holding label indices.

I want to select slices from data according to index: for each row of index, the b-th entry is the label to pick in block b, for every example in that block.
I realize this is hard to explain, so here is an example.
Let's say a row in index looks like [10, 20, 30, 40, 50].
For this row, I want to select the following slices from data: data[0, :, 10], data[1, :, 20], data[2, :, 30], data[3, :, 40], data[4, :, 50].

My final output should then have shape (n, block, example).
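To pin down the intended semantics, here is a naive double-loop reference (a sketch, assuming PyTorch; select_reference is a hypothetical helper name). It computes out[r, b] = data[b, :, index[r, b]], using the example row [10, 20, 30, 40, 50] with block = 5:

```python
import torch

def select_reference(data: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. data: (block, example, label); index: (n, block) of label ids
    n, block = index.shape
    example = data.shape[1]
    out = torch.empty(n, block, example, dtype=data.dtype, device=data.device)
    for r in range(n):
        for b in range(block):
            # pick label index[r, b] for every example in block b
            out[r, b] = data[b, :, int(index[r, b])]
    return out

# toy data: block=5 to match the example row [10, 20, 30, 40, 50]
block, example, label = 5, 3, 60
data = torch.arange(block * example * label).reshape(block, example, label)
index = torch.tensor([[10, 20, 30, 40, 50]])
out = select_reference(data, index)
print(out.shape)  # torch.Size([1, 5, 3])
```

It is slow for large n, but it is handy as ground truth when checking a vectorized version.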

Is there a clean and efficient way to do it?

This might work:

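```python
import torch

block, example, label = 2, 3, 4
N = 5

data = torch.randn(block, example, label)
# random label indices in [0, label), one per (row, block)
index = torch.empty(N, block, dtype=torch.long).random_(label)
# for each row i of index, gather data[b, :, i[b]] for every block b -> (block, example, 1)
res = torch.cat([torch.gather(data, 2, i[:, None, None].expand(-1, example, -1)) for i in index])
res = res.view(N, block, example)
```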

Could you check this code with your data?
I'm not sure about the last .view operation, or whether it's possible to get rid of the torch.cat and the list comprehension.
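A sketch of two ways to drop the torch.cat and the list comprehension, compared against the per-row gather from the answer and a naive loop. It assumes PyTorch's NumPy-style advanced indexing: when the two index tensors are separated by a ':' slice, their broadcast shape (N, block) comes first and the sliced example dimension follows. The variable names (res_idx, res_gather, res_ref) are only illustrative.

```python
import torch

block, example, label = 2, 3, 4
N = 5

data = torch.randn(block, example, label)
index = torch.empty(N, block, dtype=torch.long).random_(label)

# Per-row gather from the answer, concatenated and reshaped
res_loop = torch.cat([torch.gather(data, 2, i[:, None, None].expand(-1, example, -1)) for i in index])
res_loop = res_loop.view(N, block, example)

# Option 1: advanced indexing. arange(block) (shape [block]) broadcasts against
# index (shape [N, block]); the broadcast dims come first -> (N, block, example)
res_idx = data[torch.arange(block), :, index]

# Option 2: one batched gather over an expanded (non-copied) view of data
idx = index[:, :, None, None].expand(-1, -1, example, 1)  # (N, block, example, 1)
res_gather = torch.gather(data.unsqueeze(0).expand(N, -1, -1, -1), 3, idx).squeeze(-1)

# Naive double loop as ground truth: res_ref[r, b] = data[b, :, index[r, b]]
res_ref = torch.stack([
    torch.stack([data[b, :, int(index[r, b])] for b in range(block)])
    for r in range(N)
])

print(torch.equal(res_loop, res_ref), torch.equal(res_idx, res_ref), torch.equal(res_gather, res_ref))
# True True True
```

The first check also addresses the .view question: torch.cat stacks the N results of shape (block, example, 1) along dim 0, so the memory is already laid out row by row, then block by block, and view(N, block, example) recovers the intended layout. The indexing version is probably the most readable; the gather version avoids materializing a copy of data because expand only creates a view.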