I’m trying to write my own `index_select` to understand how it is done efficiently under the hood. Suppose we have access to `index(t, indices)`, which, given a list of raw (flat) integer indices, returns a flattened tensor of the corresponding elements of `t`.
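For experimenting in plain PyTorch, a minimal stand-in for the hypothetical `index(t, indices)` could flatten the tensor and use advanced indexing. This sketch assumes the flat indices follow the row-major order that `flatten()` produces; `index` itself is a hypothetical helper, not a PyTorch API.

```python
import torch


def index(t: torch.Tensor, indices) -> torch.Tensor:
    """Hypothetical helper: gather elements of `t` by flat (row-major) index.

    Returns a 1-D tensor with one element per entry in `indices`.
    """
    flat = t.flatten()  # all elements of t in row-major order
    return flat[torch.as_tensor(indices, dtype=torch.long)]


t = torch.arange(24).reshape(2, 3, 4) * 10  # element with flat index k has value 10 * k
values = index(t, [0, 5, 23])  # gathers the values 0, 50 and 230
print(values)
```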
What I have is the following: create a tensor holding the flat index of every element (i.e. `full(t.shape)`), loop over each `idx` in `indices`, slice into that tensor to get the corresponding flat indices, then append them to a master list of all the indices we will gather:
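```python
import torch

t = torch.arange(3 * 4 * 2 * 3).reshape(3, 4, 2, 3)  # tensor of shape (3, 4, 2, 3)
indices = [0, 1, 0]  # the indices to select along `dim`
dim = 2  # the dim to index_select over
selected_indices = []
# full(t.shape): every element holds its own flat index 0 .. t.numel() - 1
full_indices = torch.arange(t.numel()).reshape(t.shape)
for idx in indices:
    mask = [slice(None)] * t.dim()  # i.e. [:, :, :, :]
    mask[dim] = idx  # i.e. [:, :, idx, :]
    for i in full_indices[tuple(mask)].flatten():
        selected_indices.append(i.item())
```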
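We can then gather the raw values with this list of indices, `result = index(t, selected_indices)` (in plain PyTorch, `t.flatten()[selected_indices]`), and reshape the flat result to `(len(indices), 3, 4, 3)`, one `t[:, :, idx, :]` block per selected index. This gives us the correct underlying data, but in the wrong order: the selected axis comes first instead of sitting at position `dim`. We can correct this with a permute that rotates the sub-shape `[0, dim]` (dims 0 through `dim`, inclusive) to the left by one: `result = result.permute(1, 2, 0, 3)`.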
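Putting the steps together, here is a sketch of the whole pipeline generalized to any `dim`, checked against `torch.index_select`. The name `my_index_select` is hypothetical, and the stand-in for `index(t, indices)` assumes it behaves like `t.flatten()[indices]`. The permutation is built as `[1, ..., dim, 0, dim + 1, ..., ndim - 1]`, i.e. the left rotation of dims `0..dim` described above.

```python
import torch


def my_index_select(t: torch.Tensor, dim: int, indices: list) -> torch.Tensor:
    """Hypothetical loop-based index_select: gather by flat index, reshape, permute."""
    # Every element holds its own flat index 0 .. t.numel() - 1.
    full_indices = torch.arange(t.numel()).reshape(t.shape)
    selected = []
    for idx in indices:
        mask = [slice(None)] * t.dim()
        mask[dim] = idx  # e.g. [:, :, idx, :] for dim == 2
        selected.append(full_indices[tuple(mask)].flatten())
    # Stand-in for index(t, selected_indices): gather the values by flat index.
    flat = t.flatten()[torch.cat(selected)]
    # One block (t's shape without `dim`) per selected index, stacked in front.
    rest = [s for d, s in enumerate(t.shape) if d != dim]
    result = flat.reshape(len(indices), *rest)
    # Rotate dims 0..dim left by one so the selected axis lands at position `dim`.
    perm = list(range(1, dim + 1)) + [0] + list(range(dim + 1, t.dim()))
    return result.permute(*perm)


t = torch.arange(3 * 4 * 2 * 3).reshape(3, 4, 2, 3)
ours = my_index_select(t, 2, [0, 1, 0])
expected = torch.index_select(t, 2, torch.tensor([0, 1, 0]))
print(ours.shape, torch.equal(ours, expected))  # torch.Size([3, 4, 3, 3]) True
```

For `dim == 0` the permutation is the identity, which matches the intuition that the gathered layout already has the selected axis in front.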
I’m trying to wrap my head around why rotating everything up to and including the selected dim to the left works here. I can sort of see that it works on paper, but I lack an intuitive understanding. Is there an easier, more intuitive algorithm that avoids the overhead of concatenations, etc.?
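One way to see where the rotation comes from: the loop effectively indexes along `dim` as if it were the leading axis, so the gathered tensor has the selected axis in front. `torch.movedim(x, 0, dim)` moves axis 0 to position `dim` and shifts axes `1..dim` one step left, which is exactly the left rotation of `0..dim`. The sketch below is only an illustration (not necessarily how PyTorch implements `index_select` internally); it checks that `permute(1, 2, 0, 3)` equals `movedim(front, 0, 2)`, and that the loop-free "move `dim` to the front, index along axis 0, move it back" gives the same result as `torch.index_select`.

```python
import torch

t = torch.arange(3 * 4 * 2 * 3).reshape(3, 4, 2, 3)
dim = 2
idx = torch.tensor([0, 1, 0])

# Selected axis first: front[k] == t[:, :, idx[k], :], the layout the loop produces.
front = torch.stack([t.select(dim, int(i)) for i in idx])  # shape (3, 3, 4, 3)

# permute(1, 2, 0, 3) is the same as moving axis 0 to position `dim`.
rotated = front.permute(1, 2, 0, 3)
moved = torch.movedim(front, 0, dim)
print(torch.equal(rotated, moved))  # True

# Loop-free view: bring `dim` to the front, index along axis 0, move it back.
alt = torch.movedim(torch.movedim(t, dim, 0)[idx], 0, dim)
print(torch.equal(alt, torch.index_select(t, dim, idx)))  # True
```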
Here is a working example; the output is large, so I don’t include it here.
```python
import torch

# input tensor and dim
_t = torch.arange(3 * 4 * 2 * 3).reshape(3, 4, 2, 3)
dim = 2
print(_t)
print(_t.shape)

print("index_select expected result:")
_t = torch.index_select(_t, dim, torch.tensor([0, 1, 0]))
print(_t)

# My algorithm returns the result of this; here I just apply the inverse of the
# permute required for index_select so I don't have to type out my algorithm
print("Result from my algorithm:")
_t = torch.permute(_t, (2, 0, 1, 3))
print(_t)
print(_t.shape)

print("Recovered correct result by permuting (rotation to the left):")
_t = torch.permute(_t, (1, 2, 0, 3))
print(_t)
```
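Since the printed tensors are large, a compact alternative is to compare shapes and use `torch.equal`. This sketch repeats the example above and also shows that `(1, 2, 0, 3)` is the inverse of `(2, 0, 1, 3)`: the inverse of a permutation `p` can be computed as `argsort(p)`, so the round trip restores the `index_select` result exactly.

```python
import torch

t = torch.arange(3 * 4 * 2 * 3).reshape(3, 4, 2, 3)
dim = 2
expected = torch.index_select(t, dim, torch.tensor([0, 1, 0]))

# Layout produced by the gather loop: selected axis first.
gathered = torch.permute(expected, (2, 0, 1, 3))
# Rotate dims 0..dim left by one to recover the index_select layout.
recovered = torch.permute(gathered, (1, 2, 0, 3))

print(expected.shape, gathered.shape)  # torch.Size([3, 4, 3, 3]) torch.Size([3, 3, 4, 3])
print(torch.equal(recovered, expected))  # True

# The inverse of a permutation p is argsort(p): (2, 0, 1, 3) -> (1, 2, 0, 3).
inverse = torch.argsort(torch.tensor([2, 0, 1, 3])).tolist()
print(inverse)  # [1, 2, 0, 3]
```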