Why does the following return a tensor with shape [4, 4]?
import torch
test = torch.ones(10, 4)
print(test[torch.arange(10)][torch.arange(4)].shape)
Intuitively it should return the entire original tensor, test.
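You are indexing dim0 twice. Each pair of brackets applies its index to dim0 of the tensor it receives, so the second index selects the first 4 rows of the first result. This code yields the same result: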
import torch
test = torch.randn(10, 4)
a = test[torch.arange(10)]    # all 10 rows -> shape [10, 4]
b = a[torch.arange(4)]        # first 4 rows of a -> shape [4, 4]
print((test[:4] == b).all())  # tensor(True)
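A minimal sketch of the shapes at each step, assuming a recent PyTorch release; the names `step1`, `step2` and `rows_then_cols` are illustrative only. Putting `:` for dim0 in the second bracket moves the second index onto dim1, which gives back a tensor equal to the original.

```python
import torch

test = torch.ones(10, 4)

# Each bracket applies its index tensor to dim0 of whatever it receives.
step1 = test[torch.arange(10)]  # rows 0..9 of test  -> shape [10, 4]
step2 = step1[torch.arange(4)]  # rows 0..3 of step1 -> shape [4, 4]
print(step1.shape)  # torch.Size([10, 4])
print(step2.shape)  # torch.Size([4, 4])

# An explicit ":" for dim0 sends the second index to dim1 instead.
rows_then_cols = test[torch.arange(10)][:, torch.arange(4)]
print(rows_then_cols.shape)  # torch.Size([10, 4])
```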
How can I index across dim1 then?
If you want to apply the same index tensor (e.g. [0, 1]) to dim1 of every row in dim0, this works:
import torch
test = torch.randn(10, 4)
idx = torch.tensor([0, 1])
test[:, idx]  # columns 0 and 1 of every row -> shape [10, 2]
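The same column selection can also be written with `torch.index_select`, which takes the dimension as an explicit argument. A short sketch, assuming a recent PyTorch release (the names `cols` and `same` are illustrative):

```python
import torch

test = torch.randn(10, 4)
idx = torch.tensor([0, 1])

# ":" keeps every row; idx then picks columns 0 and 1 from each row.
cols = test[:, idx]
print(cols.shape)  # torch.Size([10, 2])

# index_select(input, dim, index) selects along dim=1 explicitly.
same = torch.index_select(test, 1, idx)
print(torch.equal(cols, same))  # True
```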
But what I want is to select certain indices in each dimension. For example, if I run
import torch
test = torch.ones(10, 4)
print(test[torch.arange(9), torch.arange(3)])
I expect it to print a 2-D tensor of ones with shape [9, 3]: the first 9 rows and, for each of those rows, the first 3 elements. What is the correct way to do this? Instead, the code above raises
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [9], [3]
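A sketch of one way to get the [9, 3] result, assuming a recent PyTorch release; it is not taken from the thread, and `rows`, `cols`, `grid`, `sliced` and `chained` are illustrative names. Index tensors placed in one bracket are broadcast against each other, as in NumPy: shapes [9] and [3] are incompatible, which is what the error reports, but [9, 1] and [3] broadcast to [9, 3], so element (i, j) of the result is `test[rows[i], cols[j]]`.

```python
import torch

test = torch.ones(10, 4)
rows = torch.arange(9)  # first 9 rows
cols = torch.arange(3)  # first 3 columns

# rows[:, None] has shape [9, 1]; it broadcasts with cols ([3]) to [9, 3],
# so grid[i, j] == test[rows[i], cols[j]].
grid = test[rows[:, None], cols]
print(grid.shape)  # torch.Size([9, 3])

# For contiguous ranges, plain slices give the same values.
sliced = test[:9, :3]
print(torch.equal(grid, sliced))  # True

# Indexing one dimension at a time also works.
chained = test[rows][:, cols]
print(torch.equal(grid, chained))  # True
```

For contiguous ranges the slice form is the simplest; the broadcast form is useful when the row and column indices are arbitrary tensors.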