Retrieve elements from a 3D tensor with a 2D index tensor

I am playing around with GPT2 and I have 2 tensors:

O: A tensor of shape (B, S-1, V), where B is the batch size, S is the number of timesteps, and V is the vocabulary size. This is the output of a generative model and is softmaxed along dimension 2 (the vocabulary axis).

L: A 2D tensor shaped (B, S-1) where each element is the index of the correct token for each timestep for each sample. This is basically the labels.

I want to extract the predicted probability of the correct token from tensor O based on tensor L, so that I end up with a 2D tensor shaped (B, S-1). Is there an efficient way of doing this without using loops?

Hello

If I understood correctly, you need the gather function (torch.gather). I have written an explanation of how to use it, and I hope it will be helpful for you: https://medium.com/analytics-vidhya/understanding-indexing-with-pytorch-gather-33717a84ebc4

I’m linking my own article but I believe it will be helpful.
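
For the shapes in the question, a minimal sketch of the gather approach could look like this. The small sizes and the random tensors are assumptions for illustration only; the names O and L follow the question.

```python
import torch

torch.manual_seed(0)
B, S, V = 2, 4, 5  # small illustrative sizes (assumed, not from the question)

# O: (B, S-1, V) probabilities, softmaxed over the vocabulary axis
O = torch.softmax(torch.randn(B, S - 1, V), dim=-1)
# L: (B, S-1) index of the correct token at each timestep
L = torch.randint(0, V, (B, S - 1))

# gather needs the index to have the same number of dimensions as the input,
# so add a trailing axis, gather along the vocabulary axis, then drop it again
probs = O.gather(dim=-1, index=L.unsqueeze(-1)).squeeze(-1)  # shape (B, S-1)

# Loop-based reference, only to check that both approaches agree
expected = torch.empty(B, S - 1)
for b in range(B):
    for t in range(S - 1):
        expected[b, t] = O[b, t, L[b, t]]
```

The result has the same shape as L, i.e. (B, S-1), and the double loop is there only as a sanity check.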