What you’re seeing here is the expected result of so-called advanced indexing.
The : in the first position is giving you the full length-three first dimension of your
original tensor.
The two [0]s in the second and third positions are two “index lists” that work together
to perform advanced indexing jointly on the second and third dimensions. Each index
list has shape [1], so the advanced-indexing result for those two dimensions together
also has shape [1].
Together with the : this gives you a result with an overall shape of [3, 1].
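For concreteness, here is a minimal sketch that reproduces the [3, 1] result. (I’m
assuming a shape of [3, 4, 5] for your original tensor; only its length-three first
dimension matters here.)
import torch
# hypothetical stand-in for your original tensor
u = torch.arange (60).reshape (3, 4, 5)
# slice on dim 0, joint advanced indexing on dims 1 and 2
print (u[:, [0], [0]].shape) # torch.Size([3, 1])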
I haven’t looked at the most recent pytorch documentation on advanced indexing, but
it used to be pretty sparse. My recollection is that it said, in effect, “we do what numpy
does,” and that this is correct. So for details, I recommend taking a look at numpy’s
advanced-indexing documentation.
I’ve made an example to illustrate what is going on. (For simplicity, I made your initial
tensor a little smaller and got rid of its first dimension, which just goes along for the
ride without really affecting anything.)
Here is the code:
import torch
print (torch.__version__)
import numpy as np
print (np.__version__)
t = torch.arange (15).reshape (5, 3)
print (t.shape)
print (t)
# use advanced indexing jointly on both dimensions
# each list has shape [1], so the result has shape [1]
print (t[[0], [0]].shape) # advanced indexing selects a tensor of shape [1]
print (t[[0], [0]])
print (t[0, 0].shape) # indexing with a python scalar kills the dimension
print (t[0, 0]) # and this "ordinary" indexing selects a single element
print (t[0:1, [0]].shape) # ordinary slicing selects a first dimension of length 1
print (t[0:1, [0]]) # and (trivial) advanced indexing gives an additional dimension of shape [1]
# some more interesting examples of advanced indexing
print (t[[4, 1], [1, 2]].shape)
print (t[[4, 1], [1, 2]])
# index lists have shape [2, 4, 2]
i0 = [
[[3, 2], [4, 0], [1, 1], [0, 2]],
[[0, 0], [1, 2], [2, 3], [4, 4]]
]
i1 = [
[[2, 1], [2, 0], [1, 1], [0, 2]],
[[0, 0], [1, 2], [2, 2], [0, 2]]
]
# so advanced-indexing result has shape [2, 4, 2]
print (t[i0, i1].shape)
print (t[i0, i1])
# numpy example -- pytorch appears to follow numpy's advanced-indexing semantics
a = np.arange (15).reshape (5, 3)
print (a.shape)
print (a)
print (a[i0, i1].shape)
print (a[i0, i1])
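One more wrinkle worth a quick sketch (this is also numpy behavior): the index lists
themselves get broadcast against one another before indexing, so they don’t need to
have identical shapes. Continuing with the t from above:
# index lists broadcast against each other, numpy-style:
# shape [2, 1] broadcast with shape [3] gives shape [2, 3]
j0 = [[0], [4]] # shape [2, 1]
j1 = [0, 1, 2] # shape [3]
print (t[j0, j1].shape) # torch.Size([2, 3])
print (t[j0, j1]) # rows 0 and 4 of t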
So now I have the name for this, `broadcastable`, to dig into more deeply. Thank you.
I suppose this is related to the same mechanism that the following example describes:
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([2, 4])
a > b # fails: trailing dimensions 3 and 2 don't match
a > b[:, None] # works: b[:, None] has shape [2, 1], which broadcasts against [2, 3]
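If it helps, here is a quick sketch for checking whether two shapes broadcast. (This
assumes torch.broadcast_shapes, which is available in reasonably recent pytorch
versions.)
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape [2, 3]
b = torch.tensor([2, 4]) # shape [2]
# trailing dimensions 3 and 2 don't match, so this raises a RuntimeError
try:
    torch.broadcast_shapes (a.shape, b.shape)
except RuntimeError as e:
    print (e)
# b[:, None] has shape [2, 1], which does broadcast against [2, 3]
print (torch.broadcast_shapes (a.shape, b[:, None].shape)) # torch.Size([2, 3])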