Some questions about PyTorch's indexing mechanism

Recently, I ran into a little problem with tensor indexing. Here’s the code:

import torch
a = torch.tensor([[3, 6, 1, 9], [4, 9, 4, 0], [4, 3, 8, 1]])
# index method 1
print(a[:, [0, 2, 3]])
# index method 2
print(a[[0, 1, 2], [0, 2, 3]])
# index method 3
print(a[[0, 1, 2], [[0], [2], [3]]])
# index method 4
print(a[:, [[0], [2], [3]]])

Output of index method 1:

tensor([[3, 1, 9],
        [4, 4, 0],
        [4, 8, 1]])

Output of index method 2:

tensor([3, 4, 1])

Output of index method 3:

tensor([[3, 4, 4],
        [1, 4, 8],
        [9, 0, 1]])

Output of index method 4:

tensor([[[3],
         [1],
         [9]],

        [[4],
         [4],
         [0]],

        [[4],
         [8],
         [1]]])

The mechanisms of index methods 1 and 3 are clear, but I’m struggling to understand how the other two work. It seems that the 3rd index method takes the 0th, 2nd, and 3rd elements of every row, respectively. The 4th index method is similar to the 3rd one but adds an extra dimension. Thanks for your help!

2: The result has the “shape” of the index lists (this also works with tensors, in which case it is the shape proper), and each location holds the value at the corresponding indices.
In formulas: if we call idx1 = [0, 1, 2] and idx2 = [0, 2, 3], then result[i] = a[idx1[i], idx2[i]].
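
A quick sketch to verify this, reusing a, idx1, and idx2 from above (the explicit loop is just for illustration):

import torch

a = torch.tensor([[3, 6, 1, 9], [4, 9, 4, 0], [4, 3, 8, 1]])
idx1 = [0, 1, 2]
idx2 = [0, 2, 3]
# build the result element by element: result[i] = a[idx1[i], idx2[i]]
manual = torch.stack([a[i1, i2] for i1, i2 in zip(idx1, idx2)])
print(torch.equal(a[idx1, idx2], manual))  # True
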
4: This is essentially an extension of 2: in place of the second dimension, the shape of the second index is used, and for each entry of that index, the tensor is indexed with it.
In formulas: call idx = [[0], [2], [3]]; this has shape 3x1 and the original tensor is 3x4.
Then the result is 3x[3x1] = 3x3x1, and result[:, i, j] = a[:, idx[i, j]], where I gloss over the fact that lists can’t be indexed with multiple indices.
As you would perhaps expect, the result of 4 is the same as the result of 1 but unsqueezed to add a (singleton) last dimension.
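
To make that concrete, here is a small check of both claims, with idx as a tensor so it can be indexed with two indices:

import torch

a = torch.tensor([[3, 6, 1, 9], [4, 9, 4, 0], [4, 3, 8, 1]])
idx = torch.tensor([[0], [2], [3]])  # shape 3x1

res = a[:, idx]
print(res.shape)  # torch.Size([3, 3, 1])
# check result[:, i, j] = a[:, idx[i, j]] for i = 1, j = 0
print(torch.equal(res[:, 1, 0], a[:, idx[1, 0]]))  # True
# method 4 is method 1's result with a singleton last dimension added
print(torch.equal(res, a[:, [0, 2, 3]].unsqueeze(-1)))  # True
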
Personally, I think 3 is more elaborate than 2 because it adds another dimension, similar to 1 vs. 4, but what do I know about what is difficult or easy. :)

If you are used to thinking of lists and tuples as very similar, note that their indexing behaviour is very different: with tuples, the elements operate on different dimensions, while with lists, the elements operate on the same dimension and the list “shape” informs the result shape. Tensors can take the role of lists here, but you’d want a tuple of tensors if you have multiple dimensions to index.
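
A minimal illustration of the difference, again with the tensor a from above:

import torch

a = torch.tensor([[3, 6, 1, 9], [4, 9, 4, 0], [4, 3, 8, 1]])
# tuple: the elements index different dimensions, so this is a[0, 1]
print(a[(0, 1)])  # tensor(6)
# list: the elements index the same (first) dimension, selecting rows 0 and 1
print(a[[0, 1]])  # tensor([[3, 6, 1, 9],
                  #         [4, 9, 4, 0]])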

Best regards

Thomas

That’s very clear and thorough. Thanks a lot!