Using tensors as indices

Why does the following return a tensor with shape [4,4]?

test=torch.ones(10,4)
print(test[torch.arange(10)][torch.arange(4)].shape)

Intuitively it should return the entire original tensor, test.

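You are indexing dim0 twice: the second index is applied to the rows of the intermediate result, not to its columns. This code shows that the result is just the first four rows: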

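import torch

test = torch.randn(10, 4)
a = test[torch.arange(10)]    # selects all 10 rows: shape [10, 4]
b = a[torch.arange(4)]        # indexes dim0 again: rows 0-3, shape [4, 4]
print((test[:4] == b).all())  # prints tensor(True)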
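Another way to see it: two chained dim0 indices compose into a single dim0 index, so test[i][j] selects the same rows as test[i[j]]. A minimal sketch (the names rows, second, chained and composed are only illustrative):

```python
import torch

test = torch.randn(10, 4)
rows = torch.arange(10)   # first index: applied to dim0
second = torch.arange(4)  # second index: also applied to dim0 of the result

chained = test[rows][second]
# Chaining two dim0 indices equals one dim0 index with the composed
# index rows[second], which here is just [0, 1, 2, 3].
composed = test[rows[second]]

print(chained.shape)                   # torch.Size([4, 4])
print(torch.equal(chained, composed))  # True
```

Since arange(10) maps every position to itself, the composed index is simply arange(4), which is why the result equals test[:4].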

How can I index along dim1, then?

If you want to apply one index tensor (e.g. [0, 1]) to dim1 for every element in dim0 (i.e. every row), this would work:

import torch

test = torch.randn(10, 4)
idx = torch.tensor([0, 1])
test[:, idx]  # all rows, columns 0 and 1: shape [10, 2]
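The ":" keeps every row, so only dim1 is indexed. As a sketch, the same selection can also be written with torch.index_select, which takes the dimension as an explicit argument (the variable names cols and same are only illustrative):

```python
import torch

test = torch.randn(10, 4)
idx = torch.tensor([0, 1])

cols = test[:, idx]                                # all rows, columns 0 and 1
same = torch.index_select(test, dim=1, index=idx)  # explicit-dimension form

print(cols.shape)               # torch.Size([10, 2])
print(torch.equal(cols, same))  # True
```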

But what I want is to select specific indices in each dimension; that is, if I were to do

test=torch.ones(10,4)
print(test[torch.arange(9),torch.arange(3)])

then it would print a 2-D tensor of ones of size [9,3]: the first 9 rows and, for each of those rows, the first 3 elements. What would be the equivalent functionality? Instead, the code above raises

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [9], [3]
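The error comes from how index tensors in different dimensions combine: they are broadcast against each other and paired elementwise (like zip), not expanded into a cross product, so shapes [9] and [3] are incompatible. A minimal sketch of one way to get the [9,3] block, assuming the goal is every (row, column) combination: reshape the row index into a column so the two indices broadcast to [9,3]. Indexing dim0 first and then dim1 of the result gives the same block. The names rows, cols, block and block_chained are only illustrative.

```python
import torch

test = torch.ones(10, 4)
rows = torch.arange(9)
cols = torch.arange(3)

# rows[:, None] has shape [9, 1] and cols has shape [3]; they broadcast
# to [9, 3], selecting test[rows[i], cols[j]] for every pair (i, j).
block = test[rows[:, None], cols]

# Alternative: index dim0 first, then dim1 of the intermediate result.
block_chained = test[rows][:, cols]

print(block.shape)                        # torch.Size([9, 3])
print(torch.equal(block, block_chained))  # True
```

When the indices are contiguous ranges, as here, plain slicing test[:9, :3] selects the same block and returns a view rather than a copy; the index-tensor forms are useful when the rows or columns are arbitrary.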
