How to index a tensor with another tensor

This is probably a very basic question, but I'm having a hard time finding a solution.

Suppose I have a tensor of size 100 * 5 * 3 (batch_size * seq_len * distinct_class_probs)
I need to index dim=2 (distinct_class_probs) with a tensor.

For each sequence in the batch and each token in that sequence, I want to return the probability of one particular class. Which class to return is specified by an index tensor of size 100 * 5 (batch_size * seq_len), holding one class id per token.

Also the dimension of the expected return value would be 100 * 5

Any ideas on how I can achieve this? Thank you.


I think gather is what you are looking for:

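import torch

x = torch.randn(10, 5, 3)  # (batch_size, seq_len, num_classes)
index = torch.empty(10, 5, 1, dtype=torch.long).random_(3)  # class ids in [0, 3)
out = x.gather(2, index)  # shape (10, 5, 1); out[i, j, 0] == x[i, j, index[i, j, 0]]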
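
If the index tensor has size 100 * 5 as in the question, rather than 100 * 5 * 1, one way to get a 100 * 5 result is to add a trailing dimension before calling gather and squeeze it away afterwards. The sketch below is only an illustration under that assumption: the names probs, target and picked are made up for the example, and the advanced-indexing line is included purely as a cross-check.

```python
import torch

batch_size, seq_len, num_classes = 100, 5, 3

# Per-token class probabilities, shape (100, 5, 3).
probs = torch.softmax(torch.randn(batch_size, seq_len, num_classes), dim=2)

# One class id per token, shape (100, 5), values in [0, num_classes).
target = torch.randint(0, num_classes, (batch_size, seq_len))

# gather needs an index with the same number of dims as the input,
# so add a trailing dim, gather along dim=2, then drop it again.
picked = probs.gather(2, target.unsqueeze(2)).squeeze(2)  # shape (100, 5)

# Equivalent advanced-indexing form, used here only as a cross-check.
rows = torch.arange(batch_size).unsqueeze(1)  # (100, 1)
cols = torch.arange(seq_len).unsqueeze(0)     # (1, 5)
picked_alt = probs[rows, cols, target]        # shape (100, 5)
```

Here picked[i, j] is probs[i, j, target[i, j]], which matches the 100 * 5 output asked for above.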

Thank you so much @ptrblck! I had seen both gather and index_select, but I was really confused.