Recently, I ran into a small problem with tensor indexing. Here’s the code:
import torch
a = torch.tensor([[3, 6, 1, 9],[4, 9, 4, 0],[4, 3, 8, 1]])
# index method 1
print(a[:, [0, 2, 3]])
# index method 2
print(a[[0, 1, 2], [0, 2, 3]])
# index method 3
print(a[[0, 1, 2], [[0], [2], [3]]])
# index method 4
print(a[:, [[0], [2], [3]]])
Output of index method 1:
tensor([[3, 1, 9],
        [4, 4, 0],
        [4, 8, 1]])
Output of index method 2:
tensor([3, 4, 1])
Output of index method 3:
tensor([[3, 4, 4],
        [1, 4, 8],
        [9, 0, 1]])
Output of index method 4:
tensor([[[3],
         [1],
         [9]],

        [[4],
         [4],
         [0]],

        [[4],
         [8],
         [1]]])
The mechanisms of index methods 1 and 2 are clear, but I’m struggling to understand how the other two methods work. It seems that the third method takes the elements at positions 0, 2, and 3 of every row, respectively. The fourth method is similar to the third one, but adds an extra dimension. Thanks for your help!
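One way to probe methods 3 and 4 is to compare them against explicit broadcasting. The sketch below assumes PyTorch follows NumPy-style advanced indexing, where index lists are broadcast against each other before elements are picked. The names `rows`, `cols`, `m3`, `m4`, and `manual` are introduced here only for illustration.

```python
import torch

a = torch.tensor([[3, 6, 1, 9], [4, 9, 4, 0], [4, 3, 8, 1]])

rows = torch.tensor([0, 1, 2])        # shape (3,)
cols = torch.tensor([[0], [2], [3]])  # shape (3, 1)

# Method 3: both indices are index lists, so they are broadcast
# to a common shape (3, 3) and paired up element by element.
rows_b, cols_b = torch.broadcast_tensors(rows, cols)
m3 = a[rows, cols]

# Rebuild the same result by hand from the broadcast index pairs.
manual = torch.empty_like(m3)
for i in range(3):
    for j in range(3):
        manual[i, j] = a[rows_b[i, j], cols_b[i, j]]

assert torch.equal(m3, manual)
# Equivalently, method 3 is the transpose of method 1.
assert torch.equal(m3, a[:, [0, 2, 3]].T)

# Method 4: the slice keeps dimension 0 as is, and the (3, 1) index
# replaces dimension 1, giving a result of shape (3, 3, 1).
m4 = a[:, cols]
assert m4.shape == (3, 3, 1)
# Equivalently, method 4 is method 1 with a trailing dimension added.
assert torch.equal(m4, a[:, [0, 2, 3]].unsqueeze(-1))
```

If these assertions hold, method 3 pairs every row index with every column index after broadcasting, while method 4 only reshapes the column selection of method 1.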