# Best way to index along a dimension

Hi, apologies if this has been asked before.

I have a tensor `x` of shape `(batch_size, seq_len, vocab_size)`, say `(1, 3, 2)`. I also have a tensor `i` of shape `(batch_size, seq_len)`, say `(1, 3)`, whose entries are indices into the last (vocab_size) dimension of `x`. I want to pick out those values along the last dimension.

For example,

```python
x = torch.tensor([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]])  # shape = (1, 3, 2)
i = torch.tensor([[1, 0, 1]])  # shape = (1, 3)
```

I want to get

```python
y = torch.tensor([[0.2, 0.3, 0.6]])  # shape = (1, 3)
```

Is there a clean way to do this other than using a for loop? Thank you very much!

This should work:

```python
x = torch.tensor([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]])
i = torch.tensor([[1, 0, 1]])

x[:, torch.arange(x.size(1)), i]
```
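A quick self-contained check of this, using the same tensors as above:

```python
import torch

x = torch.tensor([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]])  # (1, 3, 2)
i = torch.tensor([[1, 0, 1]])                              # (1, 3)

# For each sequence position s, pick x[:, s, i[:, s]]
y = x[:, torch.arange(x.size(1)), i]
print(torch.allclose(y.flatten(), torch.tensor([0.2, 0.3, 0.6])))  # True
```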

I tried your solution, but with more than one batch, i.e. something like

```python
x = torch.tensor([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                  [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]])  # shape = (2, 3, 2)
i = torch.tensor([[1, 0, 1], [1, 1, 1]])  # shape = (2, 3)
```

`x[:, torch.arange(x.size(1)), i]` broadcasts `torch.arange(x.size(1))` against `i` and gives

```python
tensor([[[0.2000, 0.3000, 0.6000],
         [0.2000, 0.4000, 0.6000]],

        [[0.8000, 0.9000, 1.2000],
         [0.8000, 1.0000, 1.2000]]])
# shape = (2, 2, 3)
```

whereas it should be

```python
tensor([[0.2000, 0.3000, 0.6000],
        [0.8000, 1.0000, 1.2000]])
# shape = (2, 3)
```

I was trying to index it again, but I guess I still need a loop?

No, you don't need a for loop; for batched input this should work:

```python
x[torch.arange(x.size(0)).unsqueeze(1), torch.arange(x.size(1)), i]
```
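For the record, `torch.gather` expresses the same batched lookup in a single call; a minimal sketch using the batched tensors from this thread:

```python
import torch

x = torch.tensor([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                  [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]])  # (2, 3, 2)
i = torch.tensor([[1, 0, 1], [1, 1, 1]])                  # (2, 3)

# gather along the vocab dim: out[b, s] = x[b, s, i[b, s]]
y = torch.gather(x, dim=2, index=i.unsqueeze(-1)).squeeze(-1)  # (2, 3)
print(torch.allclose(y, torch.tensor([[0.2, 0.3, 0.6],
                                      [0.8, 1.0, 1.2]])))  # True
```

`gather` requires the index tensor to have the same number of dimensions as `x`, hence the `unsqueeze(-1)` / `squeeze(-1)` pair.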

thanks so much, this worked like a charm!