Indexing with tensors

Hi, I came across an operation where someone indexed a tensor with two other tensors, and I would like to learn more about it. Can anyone recommend any resources?

Here’s an example of what I’m talking about:

query_states  # shape [10, 50, 128]
dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1)  # shape [10, 1]
M_top  # shape [10, 8], with values in the range 0 to 49
Q_reduce = query_states[dim_for_slice, M_top]  # shape [10, 8, 128]

What is the logic behind the last operation?

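PyTorch follows NumPy's indexing mechanism, so these docs explaining advanced indexing might be helpful. In short, the two index tensors are broadcast against each other (here [10, 1] and [10, 8] broadcast to [10, 8]), and the result picks Q_reduce[i, j] = query_states[i, M_top[i, j]], with the last dimension of size 128 carried along unchanged.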
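To make this concrete, here is a minimal sketch that reproduces the indexing from the question and checks it against two other ways of writing it. The random data is a placeholder with the shapes from the question, and the names `expected` and `gathered` are mine, not from the original code.

```python
import torch

# Placeholder data with the shapes from the question (values are made up).
query_states = torch.randn(10, 50, 128)
M_top = torch.randint(0, 50, (10, 8))  # indices into dim 1, values in 0..49

# Row indices 0..9 as a column, shape [10, 1].
dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1)

# Advanced indexing: the index tensors broadcast to [10, 8], and each pair
# (dim_for_slice[i, 0], M_top[i, j]) selects one vector of length 128.
Q_reduce = query_states[dim_for_slice, M_top]  # shape [10, 8, 128]

# The same selection written as an explicit loop.
expected = torch.empty(10, 8, 128)
for i in range(10):
    for j in range(8):
        expected[i, j] = query_states[i, M_top[i, j]]

# The same selection with torch.gather along dim 1; the index has to be
# expanded to cover the last dimension as well.
index = M_top.unsqueeze(-1).expand(-1, -1, query_states.size(-1))
gathered = torch.gather(query_states, 1, index)

print(Q_reduce.shape)
print(torch.equal(Q_reduce, expected), torch.equal(Q_reduce, gathered))
```

Without dim_for_slice, query_states[:, M_top] would apply every row of M_top to every batch element and give shape [10, 10, 8, 128]. Pairing each batch index with its own row of M_top is what keeps the result at [10, 8, 128].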
